关键代码验证
技术研究 人工智能 AI Agent
from abc import ABC, abstractmethod import pandas as pd from typing import List, Dict
数据抓取模块核心代码
数据源抽象层定义
# data_sources/base.py
from abc import ABC, abstractmethod
import pandas as pd
from typing import List, Dict
class DataSource(ABC):
    """Abstract base class for market data sources.

    Concrete adapters provide the HTTP session and implement the three
    fetch methods declared below.
    """

    def __init__(self, cache_manager=None):
        """Store the optional cache and build the HTTP session.

        Args:
            cache_manager: Optional cache object, or None to disable caching.
        """
        self.cache_manager = cache_manager
        # Session construction is delegated to the subclass.
        self.session = self._init_session()

    @abstractmethod
    def _init_session(self):
        """Create and return the HTTP session used for requests."""

    @abstractmethod
    def fetch_realtime_price(self, symbol: str, market: str = "A") -> Dict:
        """Fetch real-time price data for one symbol.

        Args:
            symbol: Stock code, e.g. '600519.SH'.
            market: Market type ('A', 'HK', 'US').

        Returns:
            A dict holding the real-time price data.
        """

    @abstractmethod
    def fetch_historical_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        market: str = "A"
    ) -> pd.DataFrame:
        """Fetch historical quotes for one symbol.

        Args:
            symbol: Stock code.
            start_date: Start date (YYYY-MM-DD).
            end_date: End date (YYYY-MM-DD).
            market: Market type.

        Returns:
            A DataFrame holding the historical quotes.
        """

    @abstractmethod
    def get_stock_list(self, market: str = "A") -> List[Dict]:
        """Return the list of stocks for a market.

        Args:
            market: Market type.

        Returns:
            A list of stock-information dicts.
        """
东方财富适配器实现
# data_sources/eastmoney_adapter.py
import requests
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Dict
from .base import DataSource
class EastMoneyAdapter(DataSource):
    """Data source adapter that reads quotes from East Money's HTTP API."""

    BASE_URL = "http://push2.eastmoney.com/api/qt"

    def _init_session(self):
        """Create a ``requests`` session that sends a browser-like User-Agent."""
        session = requests.Session()
        session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })
        return session

    def fetch_realtime_price(self, symbol: str, market: str = "A") -> Dict:
        """Fetch the latest quote for one symbol.

        The result is cached for 60 seconds when a cache manager is set.

        Args:
            symbol: Stock code with exchange suffix, e.g. '600519.SH'.
            market: Market type. Kept for interface compatibility; this
                adapter derives the exchange from the symbol suffix.

        Returns:
            A dict with the keys symbol, price, open, high, low, volume,
            amount, change, change_percent and timestamp, or
            ``{'error': message}`` when the request fails or the response
            carries no quote.
        """
        cache_key = f"realtime:{symbol}"
        if self.cache_manager:
            cached_data = self.cache_manager.get(cache_key)
            if cached_data:
                return cached_data

        code, market_suffix = self._parse_symbol(symbol)
        url = f"{self.BASE_URL}/stock/get"
        params = {
            'secid': f"{self._get_market_code(market_suffix)}.{code}",
            # f169/f170 are read below, so they must be requested as well;
            # otherwise change and change_percent are always 0.
            'fields': 'f43,f44,f45,f46,f47,f48,f49,f50,f51,f52,f169,f170',
            'ut': 'fa5fd1943c7b386f172d6893dbfba10b'
        }
        try:
            response = self.session.get(url, params=params, timeout=10)
            response.raise_for_status()
            stock_data = self._extract_quote(response.json())
            if stock_data is None:
                return {'error': 'No data available'}

            # NOTE(review): values are passed through exactly as the API
            # returns them; confirm whether prices need rescaling.
            result = {
                'symbol': symbol,
                'price': stock_data.get('f43', 0),             # latest price
                'open': stock_data.get('f46', 0),              # open
                'high': stock_data.get('f44', 0),              # high
                'low': stock_data.get('f45', 0),               # low
                'volume': stock_data.get('f47', 0),            # volume
                'amount': stock_data.get('f48', 0),            # turnover
                'change': stock_data.get('f169', 0),           # price change
                'change_percent': stock_data.get('f170', 0),   # percent change
                'timestamp': datetime.now().isoformat()
            }
            if self.cache_manager:
                self.cache_manager.set(cache_key, result, ttl=60)
            return result
        except Exception as e:
            # Best-effort contract: failures are reported in the result dict.
            return {'error': str(e)}

    @staticmethod
    def _extract_quote(data):
        """Return the per-stock field dict from a response body, or None."""
        payload = data.get('data') if data else None
        if not payload:
            return None
        diff = payload.get('diff')
        if diff:
            return diff[0]
        # NOTE(review): the single-stock endpoint appears to return the
        # fields directly under `data` instead of in a `diff` list; accept
        # that shape too. Confirm against a live response.
        if 'f43' in payload:
            return payload
        return None

    def fetch_historical_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        market: str = "A"
    ) -> pd.DataFrame:
        """Fetch forward-adjusted daily K-line data for one symbol.

        The parsed rows are cached for one day when a cache manager is set.

        Args:
            symbol: Stock code.
            start_date: Start date (YYYY-MM-DD).
            end_date: End date (YYYY-MM-DD).
            market: Market type. Kept for interface compatibility; unused.

        Returns:
            A DataFrame indexed by date with the columns open, close, high,
            low, volume and amount. An empty DataFrame is returned when the
            request fails or no rows come back.

        Raises:
            ValueError: If either date is not in YYYY-MM-DD format.
        """
        cache_key = f"historical:{symbol}:{start_date}:{end_date}"
        if self.cache_manager:
            cached_data = self.cache_manager.get(cache_key)
            if cached_data is not None:
                return self._build_kline_frame(cached_data)

        code, market_suffix = self._parse_symbol(symbol)
        url = f"{self.BASE_URL}/stock/kline"
        # NOTE(review): confirm host and path of the K-line endpoint.

        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
        end_dt = datetime.strptime(end_date, '%Y-%m-%d')
        params = {
            'secid': f"{self._get_market_code(market_suffix)}.{code}",
            'fields1': 'f1,f2,f3,f4,f5,f6',
            'fields2': 'f51,f52,f53,f54,f55,f56,f57,f58',
            'klt': '101',   # daily K-line
            'fqt': '1',     # forward-adjusted
            # The range is sent as YYYYMMDD dates rather than Unix timestamps.
            'beg': start_dt.strftime('%Y%m%d'),
            'end': end_dt.strftime('%Y%m%d'),
            'ut': 'fa5fd1943c7b386f172d6893dbfba10b'
        }
        try:
            response = self.session.get(url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()
            if not (data.get('data') and data['data'].get('klines')):
                return pd.DataFrame()

            rows = []
            for kline in data['data']['klines']:
                parts = kline.split(',')
                rows.append({
                    'date': parts[0],
                    'open': float(parts[1]),
                    'close': float(parts[2]),
                    'high': float(parts[3]),
                    'low': float(parts[4]),
                    'volume': float(parts[5]),
                    'amount': float(parts[6])
                })

            if self.cache_manager:
                # Cache the plain rows (date kept as a string) so that a
                # cache hit rebuilds the same date-indexed frame.
                self.cache_manager.set(cache_key, rows, ttl=86400)  # 1 day
            return self._build_kline_frame(rows)
        except Exception as e:
            print(f"Error fetching historical data for {symbol}: {e}")
            return pd.DataFrame()

    @staticmethod
    def _build_kline_frame(rows) -> pd.DataFrame:
        """Build the K-line DataFrame from row dicts, indexed by date if present."""
        df = pd.DataFrame(rows)
        # Rows without a 'date' key are returned unindexed.
        if 'date' in df.columns:
            df['date'] = pd.to_datetime(df['date'])
            df.set_index('date', inplace=True)
        return df

    def get_stock_list(self, market: str = "A") -> List[Dict]:
        """Return a hard-coded sample stock list.

        Simplified implementation: a real deployment needs the full list.
        """
        return [
            {'symbol': '600519.SH', 'name': '贵州茅台', 'market': 'A'},
            {'symbol': '000858.SZ', 'name': '五粮液', 'market': 'A'},
            {'symbol': '601318.SH', 'name': '中国平安', 'market': 'A'}
        ]

    def _parse_symbol(self, symbol: str) -> tuple:
        """Split 'CODE.SUFFIX' into (code, suffix); the suffix defaults to 'SH'."""
        if '.' in symbol:
            parts = symbol.split('.')
            # Upper-case the suffix so '600519.sh' maps to the right exchange.
            return parts[0], parts[1].upper()
        return symbol, 'SH'

    def _get_market_code(self, market_suffix: str) -> int:
        """Map an exchange suffix to the numeric prefix used in `secid`.

        Unknown suffixes fall back to 1.
        """
        market_map = {
            'SH': 1,  # Shanghai Stock Exchange
            'SZ': 0,  # Shenzhen Stock Exchange
            'BJ': 2   # Beijing Stock Exchange
        }
        return market_map.get(market_suffix, 1)
缓存管理器实现
# cache_manager.py
import json
import redis
from typing import Any, Optional
class CacheManager:
    """Cache manager that stores JSON-serialised values in Redis."""

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Create the Redis client.

        Args:
            redis_url: Connection URL passed to ``redis.from_url``.
        """
        self.redis_client = redis.from_url(redis_url, decode_responses=True)

    def get(self, key: str) -> Optional[Any]:
        """Read and JSON-decode a cached value.

        Args:
            key: Cache key.

        Returns:
            The decoded value, or None when the key is missing, the stored
            value is empty, or any error occurs (the error is printed).
        """
        try:
            data = self.redis_client.get(key)
            if data is None or data == "":
                return None
            return json.loads(data)
        except Exception as e:
            print(f"Cache get error: {e}")
            return None

    def set(self, key: str, value: Any, ttl: int = 3600) -> bool:
        """JSON-encode a value and store it with an expiry.

        Args:
            key: Cache key.
            value: Value to cache; must be JSON-serialisable.
            ttl: Time to live in seconds.

        Returns:
            True when the value was stored, False on any error (including a
            value that cannot be serialised).
        """
        try:
            data = json.dumps(value, ensure_ascii=False)
            # Coerce the client's reply so the declared bool is what callers get.
            return bool(self.redis_client.setex(key, ttl, data))
        except Exception as e:
            print(f"Cache set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete a cached value.

        Args:
            key: Cache key.

        Returns:
            True when at least one key was removed, otherwise False.
        """
        try:
            return self.redis_client.delete(key) > 0
        except Exception as e:
            print(f"Cache delete error: {e}")
            return False
技术指标计算模块核心代码
指标计算引擎
# indicators/engine.py
import pandas as pd
import talib
class IndicatorEngine:
    """Computes the most recent value of common technical indicators with TA-Lib."""

    @staticmethod
    def _column(data: pd.DataFrame, name: str):
        """Return one DataFrame column as a float array.

        NOTE(review): TA-Lib's Python wrapper is understood to reject
        non-double arrays, hence the explicit cast; confirm.
        """
        return data[name].values.astype(float)

    @staticmethod
    def _latest(*series):
        """Return the last element of each array as floats.

        Returns None when any of them is NaN, i.e. the indicator is still
        inside its warm-up window.
        """
        values = [float(s[-1]) for s in series]
        if any(pd.isna(v) for v in values):
            return None
        return values

    def calculate_ma(
        self,
        data: pd.DataFrame,
        periods: tuple = (5, 10, 20, 60)
    ) -> dict:
        """Compute simple moving averages.

        Args:
            data: DataFrame with OHLCV columns.
            periods: MA periods (any iterable of ints).

        Returns:
            A dict mapping 'MA<period>' to the latest value. Periods longer
            than the available data are omitted.
        """
        result = {}
        close = self._column(data, 'close')
        for period in periods:
            if len(close) < period:
                continue
            latest = self._latest(talib.SMA(close, timeperiod=period))
            if latest is not None:
                result[f'MA{period}'] = latest[0]
        return result

    def calculate_macd(
        self,
        data: pd.DataFrame,
        fastperiod: int = 12,
        slowperiod: int = 26,
        signalperiod: int = 9
    ) -> dict:
        """Compute the MACD indicator.

        Args:
            data: DataFrame with OHLCV columns.
            fastperiod: Fast line period.
            slowperiod: Slow line period.
            signalperiod: Signal line period.

        Returns:
            A dict with the keys DIF, DEA and MACD, or an empty dict when
            there is not enough data for a defined value.
        """
        close = self._column(data, 'close')
        if len(close) < slowperiod:
            return {}
        latest = self._latest(*talib.MACD(
            close,
            fastperiod=fastperiod,
            slowperiod=slowperiod,
            signalperiod=signalperiod
        ))
        if latest is None:
            # Long enough for the slow line but not yet for the signal line.
            return {}
        dif, dea, macd = latest
        return {'DIF': dif, 'DEA': dea, 'MACD': macd}

    def calculate_rsi(
        self,
        data: pd.DataFrame,
        period: int = 14
    ) -> "float | None":
        """Compute the RSI indicator.

        Args:
            data: DataFrame with OHLCV columns.
            period: RSI period.

        Returns:
            The latest RSI value, or None when there is not enough data.
        """
        close = self._column(data, 'close')
        if len(close) < period:
            return None
        latest = self._latest(talib.RSI(close, timeperiod=period))
        return None if latest is None else latest[0]

    def calculate_kdj(
        self,
        data: pd.DataFrame,
        fastk_period: int = 9,
        slowk_period: int = 3,
        slowd_period: int = 3
    ) -> dict:
        """Compute the KDJ indicator from TA-Lib's stochastic oscillator.

        Args:
            data: DataFrame with OHLCV columns.
            fastk_period: Look-back period of the raw K line.
            slowk_period: Smoothing period of K.
            slowd_period: Smoothing period of D.

        Returns:
            A dict with the keys K, D and J (J = 3K - 2D), or an empty dict
            when there is not enough data for a defined value.
        """
        high = self._column(data, 'high')
        low = self._column(data, 'low')
        close = self._column(data, 'close')
        if len(close) < fastk_period:
            return {}
        latest = self._latest(*talib.STOCH(
            high,
            low,
            close,
            fastk_period=fastk_period,
            slowk_period=slowk_period,
            slowd_period=slowd_period
        ))
        if latest is None:
            return {}
        k, d = latest
        return {'K': k, 'D': d, 'J': float(3 * k - 2 * d)}

    def calculate_boll(
        self,
        data: pd.DataFrame,
        period: int = 20,
        nbdevup: int = 2,
        nbdevdn: int = 2
    ) -> dict:
        """Compute Bollinger Bands.

        Args:
            data: DataFrame with OHLCV columns.
            period: Window length.
            nbdevup: Standard-deviation multiplier of the upper band.
            nbdevdn: Standard-deviation multiplier of the lower band.

        Returns:
            A dict with the keys upper, middle and lower, or an empty dict
            when there is not enough data.
        """
        close = self._column(data, 'close')
        if len(close) < period:
            return {}
        latest = self._latest(*talib.BBANDS(
            close,
            timeperiod=period,
            nbdevup=nbdevup,
            nbdevdn=nbdevdn
        ))
        if latest is None:
            return {}
        upper, middle, lower = latest
        return {'upper': upper, 'middle': middle, 'lower': lower}

    def calculate_all_indicators(self, data: pd.DataFrame) -> dict:
        """Compute every supported indicator with default parameters.

        Args:
            data: DataFrame with OHLCV columns.

        Returns:
            One flat dict merging the MA, MACD, RSI, KDJ and Bollinger
            results. Indicators without enough data are left out.
        """
        result = {}
        result.update(self.calculate_ma(data))
        result.update(self.calculate_macd(data))

        rsi = self.calculate_rsi(data)
        # Compare with None: an RSI of exactly 0.0 is a valid reading.
        if rsi is not None:
            result['RSI'] = rsi

        result.update(self.calculate_kdj(data))
        result.update(self.calculate_boll(data))
        return result
信号解读器
# indicators/interpreter.py
from typing import Dict
class SignalInterpreter:
    """Turns technical indicator values into short trading-signal texts."""

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None or NaN (NaN is the only value unequal to itself)."""
        return value is None or value != value

    def interpret_ma(self, ma_data: Dict) -> str:
        """Interpret moving-average alignment.

        Args:
            ma_data: Dict holding MA values keyed 'MA5', 'MA10', 'MA20'.

        Returns:
            Short-term and mid-term trend texts joined by a comma, or the
            insufficient-data message when any of the three is missing.
        """
        ma5 = ma_data.get('MA5')
        ma10 = ma_data.get('MA10')
        ma20 = ma_data.get('MA20')
        # Explicit missing-value test: a value of 0.0 is still data.
        if any(self._is_missing(v) for v in (ma5, ma10, ma20)):
            return "数据不足,无法判断"

        # Short-term MA against mid-term MA, then mid-term against long-term.
        signals = [
            "短期趋势向上" if ma5 > ma10 else "短期趋势向下",
            "中期趋势向上" if ma10 > ma20 else "中期趋势向下",
        ]
        return ",".join(signals)

    def interpret_macd(self, macd_data: Dict) -> str:
        """Interpret the MACD signal.

        Args:
            macd_data: Dict with the keys 'DIF', 'DEA' and 'MACD'.

        Returns:
            The signal text, or the insufficient-data message.
        """
        dif = macd_data.get('DIF')
        dea = macd_data.get('DEA')
        macd = macd_data.get('MACD')
        if any(self._is_missing(v) for v in (dif, dea, macd)):
            return "数据不足,无法判断"

        if dif > dea and macd > 0:
            return "金叉向上,买入信号"
        if dif < dea and macd < 0:
            return "死叉向下,卖出信号"
        if dif > dea and macd < 0:
            return "金叉但柱状图为负,观望"
        if dif < dea and macd > 0:
            return "死叉但柱状图为正,观望"
        return "震荡市,观望"

    def interpret_rsi(self, rsi: float) -> str:
        """Interpret the RSI value.

        Args:
            rsi: RSI value, or None when unavailable.

        Returns:
            The signal text, or the insufficient-data message.
        """
        if self._is_missing(rsi):
            return "数据不足,无法判断"

        if rsi > 80:
            return "严重超买,考虑减仓"
        if rsi > 70:
            return "超买,注意风险"
        if rsi < 20:
            return "严重超卖,考虑抄底"
        if rsi < 30:
            return "超卖,关注机会"
        return "处于正常区间,观望"

    def interpret_all(self, indicators: Dict) -> Dict[str, str]:
        """Interpret every indicator present in the input.

        Args:
            indicators: Flat dict of indicator values.

        Returns:
            A dict mapping 'MA', 'MACD', 'RSI', 'KDJ' and 'BOLL' to their
            texts; an entry appears only when its inputs are present.
        """
        result = {}

        # Only 'MA<digits>' keys are moving averages; a plain prefix test
        # would also pick up the 'MACD' key.
        ma_data = {
            k: v for k, v in indicators.items()
            if k.startswith('MA') and k[2:].isdigit()
        }
        if ma_data:
            result['MA'] = self.interpret_ma(ma_data)

        if 'DIF' in indicators and 'DEA' in indicators:
            result['MACD'] = self.interpret_macd(indicators)

        if 'RSI' in indicators:
            result['RSI'] = self.interpret_rsi(indicators['RSI'])

        if 'K' in indicators and 'D' in indicators and 'J' in indicators:
            result['KDJ'] = self.interpret_kdj(indicators)

        if 'upper' in indicators and 'lower' in indicators:
            result['BOLL'] = self.interpret_boll(indicators)

        return result

    def interpret_kdj(self, kdj_data: Dict) -> str:
        """Interpret the KDJ signal.

        Args:
            kdj_data: Dict with the keys 'K', 'D' and 'J'.

        Returns:
            The signal text, or the insufficient-data message when any of
            the three values is missing.
        """
        k = kdj_data.get('K')
        d = kdj_data.get('D')
        j = kdj_data.get('J')
        # Guard first: comparing None with a number raises TypeError.
        if any(self._is_missing(v) for v in (k, d, j)):
            return "数据不足,无法判断"

        if k > d and j > 100:
            return "超买,注意风险"
        if k < d and j < 0:
            return "超卖,关注机会"
        if k > d:
            return "金叉,买入信号"
        if k < d:
            return "死叉,卖出信号"
        return "震荡市,观望"

    def interpret_boll(self, boll_data: Dict) -> str:
        """Describe the Bollinger Band range.

        Simplified implementation: a real signal needs the current price,
        which is not available here.

        Args:
            boll_data: Dict with the keys 'upper' and 'lower'.

        Returns:
            The formatted band range, or the insufficient-data message.
        """
        upper = boll_data.get('upper')
        lower = boll_data.get('lower')
        # Explicit missing-value test: a band at 0.0 is still a value.
        if self._is_missing(upper) or self._is_missing(lower):
            return "数据不足,无法判断"
        return f"布林带区间:{lower:.2f} - {upper:.2f}"
大模型集成模块核心代码
模型客户端接口
# ai_models/base.py
from abc import ABC, abstractmethod
from typing import Dict
class ModelClient(ABC):
    """Abstract base class for large-language-model clients."""

    def __init__(self, api_key: str, model: str):
        """Store the credentials and reset the usage counters.

        Args:
            api_key: API key for the model service.
            model: Model identifier.
        """
        self.api_key = api_key
        self.model = model
        self.usage_stats = {
            'total_tokens': 0,
            'total_requests': 0,
            'failed_requests': 0
        }

    @abstractmethod
    def generate_report(self, prompt: str) -> str:
        """Generate an investment report.

        Args:
            prompt: Prompt text.

        Returns:
            The generated report content.
        """

    @abstractmethod
    def get_usage_stats(self) -> Dict:
        """Return the usage statistics.

        Returns:
            A dict of usage counters.
        """

    def _update_stats(self, tokens: int, success: bool):
        """Record the outcome of one request.

        Every call counts as a request. Tokens are added only on success;
        a failure increments the failure counter instead.

        Args:
            tokens: Tokens consumed by the request.
            success: Whether the request succeeded.
        """
        stats = self.usage_stats
        stats['total_requests'] += 1
        if success:
            stats['total_tokens'] += tokens
        else:
            stats['failed_requests'] += 1
GLM-4.7 客户端
# ai_models/glm_client.py
import requests
import json
from typing import Dict
from .base import ModelClient
class GLM4Client(ModelClient):
    """Client for the GLM-4.7 chat-completions API."""

    API_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"

    def generate_report(self, prompt: str) -> str:
        """Generate an investment report with a single user message.

        Args:
            prompt: Prompt text.

        Returns:
            The content of the first choice in the response.

        Raises:
            Exception: Any request, HTTP-status or response-shape error is
                printed, counted as a failed request and re-raised.
        """
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'model': self.model,
            'messages': [
                {
                    'role': 'user',
                    'content': prompt
                }
            ],
            'temperature': 0.7,
            'max_tokens': 4000
        }
        try:
            response = requests.post(
                self.API_URL,
                headers=headers,
                json=data,
                timeout=60
            )
            response.raise_for_status()
            result = response.json()

            content = result['choices'][0]['message']['content']
            # `usage` may be absent or null; neither should turn a
            # successful generation into a failure.
            usage = result.get('usage') or {}
            tokens = usage.get('total_tokens') or 0
            self._update_stats(tokens, success=True)
            return content
        except Exception as e:
            print(f"GLM-4.7 API error: {e}")
            self._update_stats(0, success=False)
            raise

    def get_usage_stats(self) -> Dict:
        """Return the usage counters together with the model name."""
        return {
            'model': self.model,
            **self.usage_stats
        }
Prompt 模板引擎
# ai_models/prompt_engine.py
from jinja2 import Template
from typing import Dict, List
class PromptTemplateEngine:
    """Renders the daily-report prompt with Jinja2."""

    # Jinja2 source of the prompt. The inner indicator loop falls back to an
    # empty dict so that a stock entry without an 'indicators' key renders
    # an empty list instead of raising an undefined-variable error.
    DAILY_REPORT_TEMPLATE = """
你是一位专业的投资分析师,请基于以下数据生成投资日报。
## 市场概况
{{ market_overview }}
## 技术指标分析
{% for indicator_name, indicator_value in technical_indicators.items() %}
- {{ indicator_name }}: {{ indicator_value }}
{% endfor %}
## 重点股票数据
{% for stock in stocks %}
### {{ stock.name }} ({{ stock.symbol }})
- 最新价:{{ stock.price }}
- 涨跌幅:{{ stock.change_percent }}%
- 技术指标解读:
{% for ind_name, ind_value in (stock.indicators or {}).items() %}
- {{ ind_name }}: {{ ind_value }}
{% endfor %}
{% endfor %}
## 要求
1. 分析至少 3 个行业的投资机会
2. 每个行业至少推荐 3 只股票(从上述股票中选择)
3. 提供明确的买入/卖出/持有建议
4. 给出详细的投资理由(基于技术指标和市场数据)
5. 添加风险提示和免责声明
请生成 Markdown 格式的报告,语言要专业、客观、逻辑严密。
"""

    def __init__(self):
        """Compile the daily-report template once per engine instance."""
        self.template = Template(self.DAILY_REPORT_TEMPLATE)

    def generate_prompt(
        self,
        market_overview: str,
        technical_indicators: Dict[str, float],
        stocks: List[Dict]
    ) -> str:
        """Render the prompt.

        Args:
            market_overview: Market overview text.
            technical_indicators: Indicator name to value mapping.
            stocks: Stock dicts; the template reads the keys name, symbol,
                price, change_percent and (optionally) indicators.

        Returns:
            The rendered prompt text.
        """
        return self.template.render(
            market_overview=market_overview,
            technical_indicators=technical_indicators,
            stocks=stocks
        )
输出解析器
# ai_models/output_parser.py
import re
from typing import Dict, List
class OutputParser:
    """Parses the model's Markdown report into a structured dict."""

    # A stock code, optionally followed by an exchange suffix ('600519.SH').
    _SYMBOL = r'[A-Z0-9]+(?:\.[A-Z]+)?'

    # A code counts only when it is marked up: **bold**, in parentheses
    # (ASCII or full-width) or in backticks. One capture group per form.
    _STOCK_RE = re.compile(
        r'\*\*({0})\*\*|[((]({0})[))]|`({0})`'.format(_SYMBOL)
    )

    # Numbered level-2 headings start a sector. Anchoring at line start keeps
    # '### 1. ...' sub-headings inside the sector they belong to.
    _SECTOR_RE = re.compile(
        r'^## \d+\.\s*([^\n]+)\n(.*?)(?=^## \d+\.|^## 风险提示|\Z)',
        re.DOTALL | re.MULTILINE
    )

    def parse_report(self, raw_output: str) -> Dict:
        """Parse a generated report.

        Args:
            raw_output: Raw model output in Markdown.

        Returns:
            A dict with the keys 'market_overview' (str), 'sectors' (list of
            dicts with name, analysis and stocks) and 'risk_warnings' (str).
            Sections that are not found keep their empty defaults.
        """
        result = {
            'market_overview': '',
            'sectors': [],
            'risk_warnings': ''
        }
        if not raw_output:
            return result

        # Market overview: everything up to the next heading.
        market_overview_match = re.search(
            r'## 市场概况\n(.*?)(?=\n##|\n###|$)',
            raw_output,
            re.DOTALL
        )
        if market_overview_match:
            result['market_overview'] = market_overview_match.group(1).strip()

        # Sector analyses.
        for sector_name, sector_content in self._SECTOR_RE.findall(raw_output):
            result['sectors'].append({
                'name': sector_name.strip(),
                'analysis': sector_content.strip(),
                'stocks': self._parse_stocks(sector_content)
            })

        # Risk warnings: everything after the heading.
        risk_match = re.search(
            r'## 风险提示\n(.*?)$',
            raw_output,
            re.DOTALL
        )
        if risk_match:
            result['risk_warnings'] = risk_match.group(1).strip()

        return result

    def _parse_stocks(self, sector_content: str) -> List[Dict]:
        """Extract the stocks mentioned in one sector section.

        Args:
            sector_content: Text of the sector section.

        Returns:
            One dict per distinct code (at least 4 characters long), in
            order of first appearance, with the keys 'symbol' and
            'recommendation'.
        """
        stocks = []
        seen = set()
        for groups in self._STOCK_RE.findall(sector_content):
            # Exactly one of the alternative groups is non-empty per match.
            code = next((g for g in groups if g), '')
            if len(code) < 4 or code in seen:
                continue
            seen.add(code)
            stocks.append({
                'symbol': code,
                'recommendation': self._extract_recommendation(sector_content, code)
            })
        return stocks

    def _extract_recommendation(self, content: str, symbol: str) -> str:
        """Derive a recommendation for a symbol from the surrounding text.

        Simplified implementation: the keyword is searched in the whole
        text, so every symbol in the same text gets the same answer.

        Args:
            content: Text to search.
            symbol: Stock code.

        Returns:
            '买入' or '卖出' when the text mentions both the symbol and that
            keyword (buy is checked first), otherwise '持有'.
        """
        if symbol in content:
            for keyword in ('买入', '卖出'):
                if keyword in content:
                    return keyword
        return '持有'
推送模块核心代码
推送接口
# notifiers/base.py
from abc import ABC, abstractmethod
from typing import Dict, List
class Notifier(ABC):
    """Abstract base class for push-notification channels."""

    def __init__(self, config: Dict):
        """Store the channel configuration.

        Args:
            config: Channel configuration dict.
        """
        self.config = config

    @abstractmethod
    def send_message(self, message: str, recipient: str) -> bool:
        """Send a message.

        Args:
            message: Message content.
            recipient: Receiver (a group ID, an e-mail address, etc.).

        Returns:
            Whether the message was sent successfully.
        """

    @abstractmethod
    def format_message(self, content: Dict) -> str:
        """Format report content for this channel.

        Args:
            content: Report content dict.

        Returns:
            The formatted message.
        """
钉钉适配器
# notifiers/dingtalk_notifier.py
import requests
import json
from typing import Dict
from .base import Notifier
class DingTalkNotifier(Notifier):
    """Notifier that posts Markdown messages to a DingTalk robot webhook."""

    def send_message(self, message: str, recipient: str) -> bool:
        """Post a Markdown message to a DingTalk webhook.

        Args:
            message: Message content in Markdown.
            recipient: Webhook URL.

        Returns:
            True when the response reports ``errcode == 0``; False on any
            other reply or on any error (the error is printed).
        """
        headers = {
            'Content-Type': 'application/json'
        }
        data = {
            'msgtype': 'markdown',
            'markdown': {
                'title': '投资日报',
                'text': message
            }
        }
        try:
            response = requests.post(
                recipient,
                headers=headers,
                json=data,
                timeout=10
            )
            response.raise_for_status()
            result = response.json()
            return result.get('errcode') == 0
        except Exception as e:
            print(f"DingTalk send error: {e}")
            return False

    def format_message(self, content: Dict) -> str:
        """Format report content as a DingTalk Markdown message.

        Args:
            content: Report content with the optional keys date,
                market_overview, sectors (list of dicts with name, analysis
                and stocks) and risk_warnings.

        Returns:
            The Markdown message.
        """
        parts = [
            f"# 📊 {content.get('date', '')} 投资日报\n\n",
            # Market overview
            f"## 🌏 市场概况\n{content.get('market_overview', '')}\n\n",
        ]

        # Sector analyses; `or []` tolerates keys that are present but None.
        for sector in content.get('sectors') or []:
            parts.append(f"## 🏢 {sector.get('name', '')}\n")
            parts.append(f"{sector.get('analysis', '')}\n\n")
            for stock in sector.get('stocks') or []:
                parts.append(f"### 📈 {stock.get('symbol', '')}\n")
                parts.append(f"- 建议:{stock.get('recommendation', '')}\n")

        # Risk warnings and disclaimer
        parts.append(f"\n## ⚠️ 风险提示\n{content.get('risk_warnings', '')}\n\n")
        parts.append("---\n*本报告由 AI 生成,仅供参考,不构成投资建议*")
        return "".join(parts)
定时任务调度核心代码
调度器配置
# config/scheduler.yaml
scheduler:
  jobs:
    # A-share daily report at 15:00 Shanghai time, Monday to Friday.
    - name: a_share_daily_report
      trigger:
        type: cron
        # Weekdays are written by name: APScheduler's cron trigger counts
        # days from Monday = 0, so a numeric "1-5" can be read as Tue-Sat.
        cron_expression: "0 15 * * mon-fri"
        timezone: "Asia/Shanghai"
      task: generate_report
      params:
        market: A-share
        sectors: ["科技", "金融", "消费"]
        channels: ["dingtalk", "email"]

    # US market daily report at 16:00 New York time, Monday to Friday.
    - name: us_market_daily_report
      trigger:
        type: cron
        cron_expression: "0 16 * * mon-fri"
        timezone: "America/New_York"
      task: generate_report
      params:
        market: US
        sectors: ["科技", "医疗", "金融"]
        channels: ["telegram", "email"]

    # Cache cleanup every 24 hours.
    - name: cache_cleanup
      trigger:
        type: interval
        interval_hours: 24
      task: cleanup_cache
      params:
        max_age_hours: 24
调度器实现
# scheduler/manager.py
import yaml
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from datetime import datetime
from typing import Dict
class SchedulerManager:
    """Registers and runs the scheduled jobs described in a YAML config."""

    def __init__(self, config_path: str, task_executor):
        """Create the scheduler and load its configuration.

        Args:
            config_path: Path of the YAML configuration file.
            task_executor: Object providing ``generate_report`` and
                ``cleanup_cache``.
        """
        self.scheduler = BackgroundScheduler()
        self.task_executor = task_executor
        self.config = self._load_config(config_path)

    def _load_config(self, config_path: str) -> Dict:
        """Load the scheduler configuration; an empty file yields ``{}``."""
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            return yaml.safe_load(f) or {}

    def _build_trigger(self, trigger_config: Dict):
        """Build the trigger for one job, or return None for an unknown type.

        NOTE(review): APScheduler 3.x counts weekdays from Monday = 0, so a
        numeric day-of-week range in `cron_expression` differs from standard
        cron; prefer names such as 'mon-fri' in the config.
        """
        trigger_type = trigger_config.get('type')
        if trigger_type == 'cron':
            return CronTrigger.from_crontab(
                trigger_config['cron_expression'],
                timezone=trigger_config.get('timezone', 'Asia/Shanghai')
            )
        if trigger_type == 'interval':
            return IntervalTrigger(
                hours=trigger_config.get('interval_hours', 1)
            )
        return None

    def setup_jobs(self):
        """Register every configured job with the scheduler.

        Jobs with an unknown trigger type are reported and skipped.
        """
        jobs = (self.config.get('scheduler') or {}).get('jobs') or []
        for job_config in jobs:
            job_name = job_config.get('name')
            task_name = job_config.get('task')
            # `or {}` also covers keys that are present but empty in YAML.
            trigger_config = job_config.get('trigger') or {}
            params = job_config.get('params') or {}

            trigger = self._build_trigger(trigger_config)
            if trigger is None:
                print(f"Unknown trigger type for job {job_name}")
                continue

            self.scheduler.add_job(
                func=self._execute_task,
                trigger=trigger,
                id=job_name,
                name=job_name,
                kwargs={
                    'task_name': task_name,
                    'params': params
                },
                # Makes a repeated setup replace the job instead of failing
                # on the duplicate id.
                replace_existing=True
            )
            print(f"Job {job_name} added to scheduler")

    def _execute_task(self, task_name: str, params: Dict):
        """Run one task on the task executor.

        Errors raised by the executor are printed, not propagated, so a
        failing run does not escape into the scheduler.

        Args:
            task_name: 'generate_report' or 'cleanup_cache'.
            params: Task parameters; None is treated as empty.
        """
        print(f"[{datetime.now()}] Executing task: {task_name}")
        params = params or {}
        try:
            if task_name == 'generate_report':
                self.task_executor.generate_report(
                    market=params.get('market'),
                    sectors=params.get('sectors', []),
                    channels=params.get('channels', [])
                )
            elif task_name == 'cleanup_cache':
                self.task_executor.cleanup_cache(
                    max_age_hours=params.get('max_age_hours', 24)
                )
            else:
                print(f"Unknown task: {task_name}")
        except Exception as e:
            print(f"Task execution error: {e}")

    def start(self):
        """Register the jobs and start the scheduler."""
        self.setup_jobs()
        self.scheduler.start()
        print("Scheduler started")

    def stop(self):
        """Stop the scheduler; does nothing when it is not running."""
        if not self.scheduler.running:
            return
        self.scheduler.shutdown()
        print("Scheduler stopped")
核心参考资料
- TA-Lib Python 接口文档 - 技术分析库
- Python requests 文档 - HTTP 库
- APScheduler 文档 - 定时任务调度
- Jinja2 文档 - 模板引擎
- Redis Python 客户端 - Redis 客户端