Logo
热心市民王先生

关键代码验证

技术研究 人工智能 AI Agent

from abc import ABC, abstractmethod import pandas as pd from typing import List, Dict

数据抓取模块核心代码

数据源抽象层定义

# data_sources/base.py
from abc import ABC, abstractmethod
import pandas as pd
from typing import List, Dict

class DataSource(ABC):
    """数据源抽象基类"""

    def __init__(self, cache_manager=None):
        self.cache_manager = cache_manager
        self.session = self._init_session()

    @abstractmethod
    def _init_session(self):
        """初始化 HTTP 会话"""
        pass

    @abstractmethod
    def fetch_realtime_price(self, symbol: str, market: str = "A") -> Dict:
        """
        获取实时价格数据

        Args:
            symbol: 股票代码(如 '600519.SH')
            market: 市场类型 ('A', 'HK', 'US')

        Returns:
            包含实时价格数据的字典
        """
        pass

    @abstractmethod
    def fetch_historical_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        market: str = "A"
    ) -> pd.DataFrame:
        """
        获取历史行情数据

        Args:
            symbol: 股票代码
            start_date: 开始日期 (YYYY-MM-DD)
            end_date: 结束日期 (YYYY-MM-DD)
            market: 市场类型

        Returns:
            包含历史行情的 DataFrame
        """
        pass

    @abstractmethod
    def get_stock_list(self, market: str = "A") -> List[Dict]:
        """
        获取股票列表

        Args:
            market: 市场类型

        Returns:
            股票信息列表
        """
        pass

东方财富适配器实现

# data_sources/eastmoney_adapter.py
import requests
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Dict
from .base import DataSource

class EastMoneyAdapter(DataSource):
    """东方财富数据源适配器"""

    BASE_URL = "http://push2.eastmoney.com/api/qt"

    def _init_session(self):
        """初始化 HTTP 会话"""
        session = requests.Session()
        session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })
        return session

    def fetch_realtime_price(self, symbol: str, market: str = "A") -> Dict:
        """
        获取 A 股实时价格

        Args:
            symbol: 股票代码(如 '600519.SH')

        Returns:
            实时价格数据字典
        """
        # 检查缓存
        cache_key = f"realtime:{symbol}"
        if self.cache_manager:
            cached_data = self.cache_manager.get(cache_key)
            if cached_data:
                return cached_data

        # 构造 API 请求
        code, market_suffix = self._parse_symbol(symbol)
        url = f"{self.BASE_URL}/stock/get"

        params = {
            'secid': f"{self._get_market_code(market_suffix)}.{code}",
            'fields': 'f43,f44,f45,f46,f47,f48,f49,f50,f51,f52',
            'ut': 'fa5fd1943c7b386f172d6893dbfba10b'
        }

        try:
            response = self.session.get(url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            if data.get('data') and data['data'].get('diff'):
                stock_data = data['data']['diff'][0]

                result = {
                    'symbol': symbol,
                    'price': stock_data.get('f43', 0),  # 最新价
                    'open': stock_data.get('f46', 0),     # 开盘价
                    'high': stock_data.get('f44', 0),     # 最高价
                    'low': stock_data.get('f45', 0),      # 最低价
                    'volume': stock_data.get('f47', 0),   # 成交量
                    'amount': stock_data.get('f48', 0),   # 成交额
                    'change': stock_data.get('f169', 0),  # 涨跌额
                    'change_percent': stock_data.get('f170', 0),  # 涨跌幅
                    'timestamp': datetime.now().isoformat()
                }

                # 写入缓存
                if self.cache_manager:
                    self.cache_manager.set(cache_key, result, ttl=60)

                return result
            else:
                return {'error': 'No data available'}

        except Exception as e:
            return {'error': str(e)}

    def fetch_historical_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        market: str = "A"
    ) -> pd.DataFrame:
        """
        获取历史行情数据

        Args:
            symbol: 股票代码
            start_date: 开始日期 (YYYY-MM-DD)
            end_date: 结束日期 (YYYY-MM-DD)
            market: 市场类型

        Returns:
            包含历史行情的 DataFrame
        """
        # 检查缓存
        cache_key = f"historical:{symbol}:{start_date}:{end_date}"
        if self.cache_manager:
            cached_data = self.cache_manager.get(cache_key)
            if cached_data is not None:
                return pd.DataFrame(cached_data)

        # 构造 API 请求
        code, market_suffix = self._parse_symbol(symbol)
        url = f"{self.BASE_URL}/stock/kline"

        # 转换日期格式
        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
        end_dt = datetime.strptime(end_date, '%Y-%m-%d')
        start_timestamp = int(start_dt.timestamp())
        end_timestamp = int(end_dt.timestamp())

        params = {
            'secid': f"{self._get_market_code(market_suffix)}.{code}",
            'fields1': 'f1,f2,f3,f4,f5,f6',
            'fields2': 'f51,f52,f53,f54,f55,f56,f57,f58',
            'klt': '101',  # 日K线
            'fqt': '1',    # 前复权
            'beg': start_timestamp,
            'end': end_timestamp,
            'ut': 'fa5fd1943c7b386f172d6893dbfba10b'
        }

        try:
            response = self.session.get(url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            if data.get('data') and data['data'].get('klines'):
                klines = data['data']['klines']

                # 解析 K 线数据
                df_data = []
                for kline in klines:
                    parts = kline.split(',')
                    df_data.append({
                        'date': parts[0],
                        'open': float(parts[1]),
                        'close': float(parts[2]),
                        'high': float(parts[3]),
                        'low': float(parts[4]),
                        'volume': float(parts[5]),
                        'amount': float(parts[6])
                    })

                df = pd.DataFrame(df_data)
                df['date'] = pd.to_datetime(df['date'])
                df.set_index('date', inplace=True)

                # 写入缓存
                if self.cache_manager:
                    self.cache_manager.set(
                        cache_key,
                        df.to_dict('records'),
                        ttl=86400  # 1 天
                    )

                return df
            else:
                return pd.DataFrame()

        except Exception as e:
            print(f"Error fetching historical data for {symbol}: {e}")
            return pd.DataFrame()

    def get_stock_list(self, market: str = "A") -> List[Dict]:
        """获取股票列表"""
        # 简化实现,实际应用中需要完整的股票列表
        # 这里仅返回示例数据
        return [
            {'symbol': '600519.SH', 'name': '贵州茅台', 'market': 'A'},
            {'symbol': '000858.SZ', 'name': '五粮液', 'market': 'A'},
            {'symbol': '601318.SH', 'name': '中国平安', 'market': 'A'}
        ]

    def _parse_symbol(self, symbol: str) -> tuple:
        """解析股票代码"""
        if '.' in symbol:
            parts = symbol.split('.')
            return parts[0], parts[1]
        return symbol, 'SH'

    def _get_market_code(self, market_suffix: str) -> int:
        """获取市场代码"""
        market_map = {
            'SH': 1,  # 上交所
            'SZ': 0,  # 深交所
            'BJ': 2   # 北交所
        }
        return market_map.get(market_suffix, 1)

缓存管理器实现

# cache_manager.py
import json
import redis
from typing import Any, Optional

class CacheManager:
    """Redis 缓存管理器"""

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.redis_client = redis.from_url(redis_url, decode_responses=True)

    def get(self, key: str) -> Optional[Any]:
        """
        从缓存获取数据

        Args:
            key: 缓存键

        Returns:
            缓存值,如果不存在返回 None
        """
        try:
            data = self.redis_client.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            print(f"Cache get error: {e}")
            return None

    def set(self, key: str, value: Any, ttl: int = 3600) -> bool:
        """
        设置缓存

        Args:
            key: 缓存键
            value: 缓存值
            ttl: 过期时间(秒)

        Returns:
            是否设置成功
        """
        try:
            data = json.dumps(value, ensure_ascii=False)
            return self.redis_client.setex(key, ttl, data)
        except Exception as e:
            print(f"Cache set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        """
        删除缓存

        Args:
            key: 缓存键

        Returns:
            是否删除成功
        """
        try:
            return self.redis_client.delete(key) > 0
        except Exception as e:
            print(f"Cache delete error: {e}")
            return False

技术指标计算模块核心代码

指标计算引擎

# indicators/engine.py
import pandas as pd
import talib

class IndicatorEngine:
    """技术指标计算引擎"""

    def calculate_ma(
        self,
        data: pd.DataFrame,
        periods: list = [5, 10, 20, 60]
    ) -> dict:
        """
        计算移动平均线

        Args:
            data: 包含 OHLCV 数据的 DataFrame
            periods: MA 周期列表

        Returns:
            包含各周期 MA 的字典
        """
        result = {}
        close = data['close'].values

        for period in periods:
            if len(close) >= period:
                ma = talib.SMA(close, timeperiod=period)
                result[f'MA{period}'] = ma[-1]

        return result

    def calculate_macd(
        self,
        data: pd.DataFrame,
        fastperiod: int = 12,
        slowperiod: int = 26,
        signalperiod: int = 9
    ) -> dict:
        """
        计算 MACD 指标

        Args:
            data: 包含 OHLCV 数据的 DataFrame
            fastperiod: 快线周期
            slowperiod: 慢线周期
            signalperiod: 信号线周期

        Returns:
            包含 MACD 数据的字典
        """
        close = data['close'].values

        if len(close) < slowperiod:
            return {}

        dif, dea, macd = talib.MACD(
            close,
            fastperiod=fastperiod,
            slowperiod=slowperiod,
            signalperiod=signalperiod
        )

        return {
            'DIF': float(dif[-1]),
            'DEA': float(dea[-1]),
            'MACD': float(macd[-1])
        }

    def calculate_rsi(
        self,
        data: pd.DataFrame,
        period: int = 14
    ) -> float:
        """
        计算 RSI 指标

        Args:
            data: 包含 OHLCV 数据的 DataFrame
            period: RSI 周期

        Returns:
            RSI 值
        """
        close = data['close'].values

        if len(close) < period:
            return None

        rsi = talib.RSI(close, timeperiod=period)
        return float(rsi[-1])

    def calculate_kdj(
        self,
        data: pd.DataFrame,
        fastk_period: int = 9,
        slowk_period: int = 3,
        slowd_period: int = 3
    ) -> dict:
        """
        计算 KDJ 指标

        Args:
            data: 包含 OHLCV 数据的 DataFrame
            fastk_period: K 线周期
            slowk_period: K 值平滑周期
            slowd_period: D 值平滑周期

        Returns:
            包含 KDJ 数据的字典
        """
        high = data['high'].values
        low = data['low'].values
        close = data['close'].values

        if len(close) < fastk_period:
            return {}

        k, d = talib.STOCH(
            high,
            low,
            close,
            fastk_period=fastk_period,
            slowk_period=slowk_period,
            slowd_period=slowd_period
        )

        j = 3 * k[-1] - 2 * d[-1]

        return {
            'K': float(k[-1]),
            'D': float(d[-1]),
            'J': float(j)
        }

    def calculate_boll(
        self,
        data: pd.DataFrame,
        period: int = 20,
        nbdevup: int = 2,
        nbdevdn: int = 2
    ) -> dict:
        """
        计算布林带指标

        Args:
            data: 包含 OHLCV 数据的 DataFrame
            period: 周期
            nbdevup: 上轨标准差倍数
            nbdevdn: 下轨标准差倍数

        Returns:
            包含布林带数据的字典
        """
        close = data['close'].values

        if len(close) < period:
            return {}

        upper, middle, lower = talib.BBANDS(
            close,
            timeperiod=period,
            nbdevup=nbdevup,
            nbdevdn=nbdevdn
        )

        return {
            'upper': float(upper[-1]),
            'middle': float(middle[-1]),
            'lower': float(lower[-1])
        }

    def calculate_all_indicators(self, data: pd.DataFrame) -> dict:
        """
        计算所有技术指标

        Args:
            data: 包含 OHLCV 数据的 DataFrame

        Returns:
            包含所有技术指标的字典
        """
        result = {}

        # 移动平均线
        result.update(self.calculate_ma(data))

        # MACD
        result.update(self.calculate_macd(data))

        # RSI
        rsi = self.calculate_rsi(data)
        if rsi:
            result['RSI'] = rsi

        # KDJ
        result.update(self.calculate_kdj(data))

        # 布林带
        result.update(self.calculate_boll(data))

        return result

信号解读器

# indicators/interpreter.py
from typing import Dict

class SignalInterpreter:
    """技术指标信号解读器"""

    def interpret_ma(self, ma_data: Dict) -> str:
        """
        解读移动平均线信号

        Args:
            ma_data: 包含各周期 MA 的字典

        Returns:
            交易信号解读
        """
        ma5 = ma_data.get('MA5')
        ma10 = ma_data.get('MA10')
        ma20 = ma_data.get('MA20')

        if not ma5 or not ma10 or not ma20:
            return "数据不足,无法判断"

        signals = []

        # 短期均线 vs 中期均线
        if ma5 > ma10:
            signals.append("短期趋势向上")
        else:
            signals.append("短期趋势向下")

        # 中期均线 vs 长期均线
        if ma10 > ma20:
            signals.append("中期趋势向上")
        else:
            signals.append("中期趋势向下")

        return ",".join(signals)

    def interpret_macd(self, macd_data: Dict) -> str:
        """
        解读 MACD 信号

        Args:
            macd_data: 包含 MACD 数据的字典

        Returns:
            交易信号解读
        """
        dif = macd_data.get('DIF')
        dea = macd_data.get('DEA')
        macd = macd_data.get('MACD')

        if dif is None or dea is None or macd is None:
            return "数据不足,无法判断"

        if dif > dea and macd > 0:
            return "金叉向上,买入信号"
        elif dif < dea and macd < 0:
            return "死叉向下,卖出信号"
        elif dif > dea and macd < 0:
            return "金叉但柱状图为负,观望"
        elif dif < dea and macd > 0:
            return "死叉但柱状图为正,观望"
        else:
            return "震荡市,观望"

    def interpret_rsi(self, rsi: float) -> str:
        """
        解读 RSI 信号

        Args:
            rsi: RSI 值

        Returns:
            交易信号解读
        """
        if rsi is None:
            return "数据不足,无法判断"

        if rsi > 80:
            return "严重超买,考虑减仓"
        elif rsi > 70:
            return "超买,注意风险"
        elif rsi < 20:
            return "严重超卖,考虑抄底"
        elif rsi < 30:
            return "超卖,关注机会"
        else:
            return "处于正常区间,观望"

    def interpret_all(self, indicators: Dict) -> Dict[str, str]:
        """
        解读所有技术指标

        Args:
            indicators: 包含所有技术指标的字典

        Returns:
            包含各指标解读的字典
        """
        result = {}

        # 移动平均线
        ma_data = {k: v for k, v in indicators.items() if k.startswith('MA')}
        if ma_data:
            result['MA'] = self.interpret_ma(ma_data)

        # MACD
        if 'DIF' in indicators and 'DEA' in indicators:
            result['MACD'] = self.interpret_macd(indicators)

        # RSI
        if 'RSI' in indicators:
            result['RSI'] = self.interpret_rsi(indicators['RSI'])

        # KDJ
        if 'K' in indicators and 'D' in indicators and 'J' in indicators:
            result['KDJ'] = self.interpret_kdj(indicators)

        # 布林带
        if 'upper' in indicators and 'lower' in indicators:
            result['BOLL'] = self.interpret_boll(indicators)

        return result

    def interpret_kdj(self, kdj_data: Dict) -> str:
        """解读 KDJ 信号"""
        k = kdj_data.get('K')
        d = kdj_data.get('D')
        j = kdj_data.get('J')

        if k > d and j > 100:
            return "超买,注意风险"
        elif k < d and j < 0:
            return "超卖,关注机会"
        elif k > d:
            return "金叉,买入信号"
        elif k < d:
            return "死叉,卖出信号"
        else:
            return "震荡市,观望"

    def interpret_boll(self, boll_data: Dict) -> str:
        """解读布林带信号"""
        upper = boll_data.get('upper')
        lower = boll_data.get('lower')

        # 注意:这里需要实际的价格数据进行判断
        # 简化实现
        if upper and lower:
            return f"布林带区间:{lower:.2f} - {upper:.2f}"
        return "数据不足,无法判断"

大模型集成模块核心代码

模型客户端接口

# ai_models/base.py
from abc import ABC, abstractmethod
from typing import Dict

class ModelClient(ABC):
    """大模型客户端抽象基类"""

    def __init__(self, api_key: str, model: str):
        self.api_key = api_key
        self.model = model
        self.usage_stats = {
            'total_tokens': 0,
            'total_requests': 0,
            'failed_requests': 0
        }

    @abstractmethod
    def generate_report(self, prompt: str) -> str:
        """
        生成投资报告

        Args:
            prompt: 提示词

        Returns:
            生成的报告内容
        """
        pass

    @abstractmethod
    def get_usage_stats(self) -> Dict:
        """
        获取使用统计

        Returns:
            使用统计数据
        """
        pass

    def _update_stats(self, tokens: int, success: bool):
        """更新使用统计"""
        self.usage_stats['total_requests'] += 1
        if success:
            self.usage_stats['total_tokens'] += tokens
        else:
            self.usage_stats['failed_requests'] += 1

GLM-4.7 客户端

# ai_models/glm_client.py
import requests
import json
from typing import Dict
from .base import ModelClient

class GLM4Client(ModelClient):
    """GLM-4.7 模型客户端"""

    API_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"

    def generate_report(self, prompt: str) -> str:
        """
        使用 GLM-4.7 生成投资报告

        Args:
            prompt: 提示词

        Returns:
            生成的报告内容
        """
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }

        data = {
            'model': self.model,
            'messages': [
                {
                    'role': 'user',
                    'content': prompt
                }
            ],
            'temperature': 0.7,
            'max_tokens': 4000
        }

        try:
            response = requests.post(
                self.API_URL,
                headers=headers,
                json=data,
                timeout=60
            )
            response.raise_for_status()

            result = response.json()

            # 提取生成的内容
            content = result['choices'][0]['message']['content']
            tokens = result.get('usage', {}).get('total_tokens', 0)

            self._update_stats(tokens, success=True)

            return content

        except Exception as e:
            print(f"GLM-4.7 API error: {e}")
            self._update_stats(0, success=False)
            raise

    def get_usage_stats(self) -> Dict:
        """获取使用统计"""
        return {
            'model': self.model,
            **self.usage_stats
        }

Prompt 模板引擎

# ai_models/prompt_engine.py
from jinja2 import Template
from typing import Dict, List

class PromptTemplateEngine:
    """Prompt 模板引擎"""

    DAILY_REPORT_TEMPLATE = """
你是一位专业的投资分析师,请基于以下数据生成投资日报。

## 市场概况
{{ market_overview }}

## 技术指标分析
{% for indicator_name, indicator_value in technical_indicators.items() %}
- {{ indicator_name }}: {{ indicator_value }}
{% endfor %}

## 重点股票数据
{% for stock in stocks %}
### {{ stock.name }} ({{ stock.symbol }})
- 最新价:{{ stock.price }}
- 涨跌幅:{{ stock.change_percent }}%
- 技术指标解读:
{% for ind_name, ind_value in stock.indicators.items() %}
  - {{ ind_name }}: {{ ind_value }}
{% endfor %}
{% endfor %}

## 要求
1. 分析至少 3 个行业的投资机会
2. 每个行业至少推荐 3 只股票(从上述股票中选择)
3. 提供明确的买入/卖出/持有建议
4. 给出详细的投资理由(基于技术指标和市场数据)
5. 添加风险提示和免责声明

请生成 Markdown 格式的报告,语言要专业、客观、逻辑严密。
"""

    def __init__(self):
        self.template = Template(self.DAILY_REPORT_TEMPLATE)

    def generate_prompt(
        self,
        market_overview: str,
        technical_indicators: Dict[str, float],
        stocks: List[Dict]
    ) -> str:
        """
        生成 Prompt

        Args:
            market_overview: 市场概况描述
            technical_indicators: 技术指标数据
            stocks: 股票数据列表

        Returns:
            生成的 Prompt
        """
        return self.template.render(
            market_overview=market_overview,
            technical_indicators=technical_indicators,
            stocks=stocks
        )

输出解析器

# ai_models/output_parser.py
import re
from typing import Dict, List

class OutputParser:
    """大模型输出解析器"""

    def parse_report(self, raw_output: str) -> Dict:
        """
        解析大模型生成的报告

        Args:
            raw_output: 大模型原始输出

        Returns:
            解析后的结构化报告
        """
        result = {
            'market_overview': '',
            'sectors': [],
            'risk_warnings': ''
        }

        # 解析市场概况
        market_overview_match = re.search(
            r'## 市场概况\n(.*?)(?=\n##|\n###|$)',
            raw_output,
            re.DOTALL
        )
        if market_overview_match:
            result['market_overview'] = market_overview_match.group(1).strip()

        # 解析行业分析
        sector_pattern = r'## \d+\.\s*([^\n]+)\n(.*?)(?=## \d+\.|## 风险提示|$)'
        sector_matches = re.findall(sector_pattern, raw_output, re.DOTALL)

        for sector_name, sector_content in sector_matches:
            sector_data = {
                'name': sector_name.strip(),
                'analysis': sector_content.strip(),
                'stocks': self._parse_stocks(sector_content)
            }
            result['sectors'].append(sector_data)

        # 解析风险提示
        risk_match = re.search(
            r'## 风险提示\n(.*?)$',
            raw_output,
            re.DOTALL
        )
        if risk_match:
            result['risk_warnings'] = risk_match.group(1).strip()

        return result

    def _parse_stocks(self, sector_content: str) -> List[Dict]:
        """
        解析行业内容中的股票信息

        Args:
            sector_content: 行业内容

        Returns:
            股票信息列表
        """
        stocks = []

        # 解析股票代码和名称
        stock_pattern = r'\*\*([A-Z0-9]+)\*\*|(([A-Z0-9]+))|`([A-Z0-9]+)`'
        stock_matches = re.findall(stock_pattern, sector_content)

        for match in stock_matches:
            code = match[0] or match[1] or match[2]
            if code and len(code) >= 4:
                # 提取推荐建议
                recommendation = self._extract_recommendation(sector_content, code)

                stocks.append({
                    'symbol': code,
                    'recommendation': recommendation
                })

        return stocks

    def _extract_recommendation(self, content: str, symbol: str) -> str:
        """
        提取股票的推荐建议

        Args:
            content: 内容
            symbol: 股票代码

        Returns:
            推荐建议
        """
        # 简化实现,实际需要更复杂的解析逻辑
        if '买入' in content and symbol in content:
            return '买入'
        elif '卖出' in content and symbol in content:
            return '卖出'
        else:
            return '持有'

推送模块核心代码

推送接口

# notifiers/base.py
from abc import ABC, abstractmethod
from typing import Dict, List

class Notifier(ABC):
    """推送接口抽象基类"""

    def __init__(self, config: Dict):
        self.config = config

    @abstractmethod
    def send_message(self, message: str, recipient: str) -> bool:
        """
        发送消息

        Args:
            message: 消息内容
            recipient: 接收者(可以是群组 ID、邮箱地址等)

        Returns:
            是否发送成功
        """
        pass

    @abstractmethod
    def format_message(self, content: Dict) -> str:
        """
        格式化消息内容

        Args:
            content: 报告内容字典

        Returns:
            格式化后的消息
        """
        pass

钉钉适配器

# notifiers/dingtalk_notifier.py
import requests
import json
from typing import Dict
from .base import Notifier

class DingTalkNotifier(Notifier):
    """钉钉机器人推送适配器"""

    def send_message(self, message: str, recipient: str) -> bool:
        """
        发送钉钉消息

        Args:
            message: 消息内容(Markdown 格式)
            recipient: Webhook URL

        Returns:
            是否发送成功
        """
        headers = {
            'Content-Type': 'application/json'
        }

        data = {
            'msgtype': 'markdown',
            'markdown': {
                'title': '投资日报',
                'text': message
            }
        }

        try:
            response = requests.post(
                recipient,
                headers=headers,
                json=data,
                timeout=10
            )
            response.raise_for_status()

            result = response.json()
            return result.get('errcode') == 0

        except Exception as e:
            print(f"DingTalk send error: {e}")
            return False

    def format_message(self, content: Dict) -> str:
        """
        格式化为钉钉 Markdown 消息

        Args:
            content: 报告内容

        Returns:
            Markdown 格式消息
        """
        message = f"# 📊 {content.get('date', '')} 投资日报\n\n"

        # 市场概况
        message += f"## 🌏 市场概况\n{content.get('market_overview', '')}\n\n"

        # 行业分析
        for sector in content.get('sectors', []):
            message += f"## 🏢 {sector.get('name', '')}\n"
            message += f"{sector.get('analysis', '')}\n\n"

            for stock in sector.get('stocks', []):
                message += f"### 📈 {stock.get('symbol', '')}\n"
                message += f"- 建议:{stock.get('recommendation', '')}\n"

        # 风险提示
        message += f"\n## ⚠️ 风险提示\n{content.get('risk_warnings', '')}\n\n"
        message += "---\n*本报告由 AI 生成,仅供参考,不构成投资建议*"

        return message

定时任务调度核心代码

调度器配置

# config/scheduler.yaml
scheduler:
  jobs:
    - name: a_share_daily_report
      trigger:
        type: cron
        cron_expression: "0 15 * * 1-5"
        timezone: "Asia/Shanghai"
      task: generate_report
      params:
        market: A-share
        sectors: ["科技", "金融", "消费"]
        channels: ["dingtalk", "email"]

    - name: us_market_daily_report
      trigger:
        type: cron
        cron_expression: "0 16 * * 1-5"
        timezone: "America/New_York"
      task: generate_report
      params:
        market: US
        sectors: ["科技", "医疗", "金融"]
        channels: ["telegram", "email"]

    - name: cache_cleanup
      trigger:
        type: interval
        interval_hours: 24
      task: cleanup_cache
      params:
        max_age_hours: 24

调度器实现

# scheduler/manager.py
import yaml
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from datetime import datetime
from typing import Dict

class SchedulerManager:
    """定时任务调度管理器"""

    def __init__(self, config_path: str, task_executor):
        """
        初始化调度器

        Args:
            config_path: 配置文件路径
            task_executor: 任务执行器
        """
        self.scheduler = BackgroundScheduler()
        self.task_executor = task_executor
        self.config = self._load_config(config_path)

    def _load_config(self, config_path: str) -> Dict:
        """加载调度器配置"""
        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def setup_jobs(self):
        """设置定时任务"""
        jobs = self.config.get('scheduler', {}).get('jobs', [])

        for job_config in jobs:
            job_name = job_config.get('name')
            trigger_config = job_config.get('trigger', {})
            task_name = job_config.get('task')
            params = job_config.get('params', {})

            # 创建触发器
            if trigger_config.get('type') == 'cron':
                trigger = CronTrigger.from_crontab(
                    trigger_config['cron_expression'],
                    timezone=trigger_config.get('timezone', 'Asia/Shanghai')
                )
            elif trigger_config.get('type') == 'interval':
                trigger = IntervalTrigger(
                    hours=trigger_config.get('interval_hours', 1)
                )
            else:
                print(f"Unknown trigger type for job {job_name}")
                continue

            # 添加任务
            self.scheduler.add_job(
                func=self._execute_task,
                trigger=trigger,
                id=job_name,
                name=job_name,
                kwargs={
                    'task_name': task_name,
                    'params': params
                }
            )

            print(f"Job {job_name} added to scheduler")

    def _execute_task(self, task_name: str, params: Dict):
        """
        执行任务

        Args:
            task_name: 任务名称
            params: 任务参数
        """
        print(f"[{datetime.now()}] Executing task: {task_name}")

        try:
            if task_name == 'generate_report':
                self.task_executor.generate_report(
                    market=params.get('market'),
                    sectors=params.get('sectors', []),
                    channels=params.get('channels', [])
                )
            elif task_name == 'cleanup_cache':
                self.task_executor.cleanup_cache(
                    max_age_hours=params.get('max_age_hours', 24)
                )
            else:
                print(f"Unknown task: {task_name}")

        except Exception as e:
            print(f"Task execution error: {e}")

    def start(self):
        """启动调度器"""
        self.setup_jobs()
        self.scheduler.start()
        print("Scheduler started")

    def stop(self):
        """停止调度器"""
        self.scheduler.shutdown()
        print("Scheduler stopped")

核心参考资料