170 lines
4.1 KiB
Go
170 lines
4.1 KiB
Go
package script
|
||
|
||
import (
|
||
"crypto/sha256"
|
||
"embed"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"maps"
|
||
"os"
|
||
"path/filepath"
|
||
"sync"
|
||
|
||
"github.com/infraboard/mcube/v2/ioc/config/log"
|
||
"github.com/rs/zerolog"
|
||
)
|
||
|
||
// 编译时嵌入的脚本 hash 文件
|
||
//
|
||
//go:embed hashes.json
|
||
var scriptHashesFS embed.FS
|
||
|
||
// ScriptIntegrityManager 脚本完整性管理器
|
||
// 脚本 hash 在编译时计算并硬编码到二进制中
|
||
// 在执行脚本前校验 hash,防止脚本被篡改
|
||
// 这确保了生产环境的脚本不会被篡改,即使脚本目录变为可写
|
||
type ScriptIntegrityManager struct {
|
||
log *zerolog.Logger
|
||
mu sync.RWMutex
|
||
|
||
// 注册的脚本 hash (相对路径 -> hash),从编译时常量加载
|
||
registeredHashes map[string]string
|
||
|
||
// 脚本根目录
|
||
scriptDir string
|
||
|
||
// 是否启用校验(编译时参数)
|
||
enabled bool
|
||
}
|
||
|
||
// NewScriptIntegrityManager 创建脚本完整性管理器
|
||
// 从编译时嵌入的 hash 文件加载脚本 hash
|
||
func NewScriptIntegrityManager(scriptDir string, enabled bool) *ScriptIntegrityManager {
|
||
m := &ScriptIntegrityManager{
|
||
log: log.Sub("script_integrity"),
|
||
registeredHashes: make(map[string]string),
|
||
scriptDir: scriptDir,
|
||
enabled: enabled,
|
||
}
|
||
|
||
// 如果启用了校验,加载编译时的 hash
|
||
if enabled {
|
||
if err := m.loadCompiledHashes(); err != nil {
|
||
m.log.Error().Err(err).Msg("加载编译时的脚本 hash 失败")
|
||
}
|
||
}
|
||
|
||
return m
|
||
}
|
||
|
||
// loadCompiledHashes 从编译时嵌入的 hash 文件加载脚本 hash
|
||
func (m *ScriptIntegrityManager) loadCompiledHashes() error {
|
||
data, err := scriptHashesFS.ReadFile("hashes.json")
|
||
if err != nil {
|
||
return fmt.Errorf("读取嵌入的 hash 文件失败: %v", err)
|
||
}
|
||
|
||
var hashes map[string]string
|
||
if err := json.Unmarshal(data, &hashes); err != nil {
|
||
return fmt.Errorf("解析 hash 文件失败: %v", err)
|
||
}
|
||
|
||
m.mu.Lock()
|
||
m.registeredHashes = hashes
|
||
m.mu.Unlock()
|
||
|
||
m.log.Info().Int("count", len(hashes)).Msg("加载编译时的脚本 hash 完成")
|
||
return nil
|
||
}
|
||
|
||
// Enable 启用校验
|
||
func (m *ScriptIntegrityManager) Enable() {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
m.enabled = true
|
||
}
|
||
|
||
// Disable 禁用校验
|
||
func (m *ScriptIntegrityManager) Disable() {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
m.enabled = false
|
||
}
|
||
|
||
// IsEnabled 是否启用校验
|
||
func (m *ScriptIntegrityManager) IsEnabled() bool {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
return m.enabled
|
||
}
|
||
|
||
// VerifyScript 校验脚本完整性
|
||
// 在执行脚本前调用,验证脚本是否被篡改
|
||
func (m *ScriptIntegrityManager) VerifyScript(scriptPath string) error {
|
||
if !m.enabled {
|
||
return nil
|
||
}
|
||
|
||
// 计算相对路径
|
||
relPath, err := filepath.Rel(m.scriptDir, scriptPath)
|
||
if err != nil {
|
||
return fmt.Errorf("计算相对路径失败: %v", err)
|
||
}
|
||
|
||
m.mu.RLock()
|
||
expectedHash, exists := m.registeredHashes[relPath]
|
||
m.mu.RUnlock()
|
||
|
||
if !exists {
|
||
return fmt.Errorf("脚本未注册: %s (可能是新增的脚本,请重启 Agent)", relPath)
|
||
}
|
||
|
||
// 计算当前 hash
|
||
currentHash, err := m.calculateFileHash(scriptPath)
|
||
if err != nil {
|
||
return fmt.Errorf("计算脚本 hash 失败: %v", err)
|
||
}
|
||
|
||
// 对比 hash
|
||
if currentHash != expectedHash {
|
||
m.log.Error().
|
||
Str("script", relPath).
|
||
Str("expected_hash", expectedHash).
|
||
Str("current_hash", currentHash).
|
||
Msg("脚本完整性校验失败:脚本可能被篡改")
|
||
return fmt.Errorf("脚本完整性校验失败: %s (hash 不匹配)", relPath)
|
||
}
|
||
|
||
m.log.Debug().Str("script", relPath).Str("hash", currentHash).Msg("脚本完整性校验通过")
|
||
return nil
|
||
}
|
||
|
||
// calculateFileHash 计算文件的 SHA256 hash
|
||
func (m *ScriptIntegrityManager) calculateFileHash(filePath string) (string, error) {
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer file.Close()
|
||
|
||
hash := sha256.New()
|
||
if _, err := io.Copy(hash, file); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return hex.EncodeToString(hash.Sum(nil)), nil
|
||
}
|
||
|
||
// GetRegisteredScripts 获取已注册的脚本列表
|
||
func (m *ScriptIntegrityManager) GetRegisteredScripts() map[string]string {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]string, len(m.registeredHashes))
|
||
maps.Copy(result, m.registeredHashes)
|
||
return result
|
||
}
|