|
|
@@ -7,6 +7,7 @@ import (
|
|
|
"github.com/google/gopacket/layers"
|
|
|
"github.com/google/gopacket/pcap"
|
|
|
"os"
|
|
|
+ "regexp"
|
|
|
"strings"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -41,23 +42,29 @@ func GetLoopbackInterface() (loopbackInterfaceName string, err error) {
|
|
|
return loopbackInterfaceName, nil
|
|
|
}
|
|
|
|
|
|
-func GetAllLoopbackInterface() (interfaceNames []string, err error) {
|
|
|
- interfaceNames = []string{}
|
|
|
+func GetAllLoopbackInterface() (interfaceNames map[string]string, err error) {
|
|
|
+ interfaceNames = map[string]string{}
|
|
|
// 获取所有网络接口
|
|
|
interfaces, err := pcap.FindAllDevs()
|
|
|
if err != nil {
|
|
|
glog.XWarning(fmt.Sprintf("pcap.FindAllDevs error : %v\n", err))
|
|
|
return interfaceNames, err
|
|
|
}
|
|
|
-
|
|
|
+ cp := regexp.MustCompile(`\d{1,3}.\d{1,3}.\d{1,3}.\d{1,3}`)
|
|
|
for _, face := range interfaces {
|
|
|
-
|
|
|
+ addr := ""
|
|
|
for _, address := range face.Addresses {
|
|
|
-
|
|
|
- println(face.Name + "-----------------" + address.IP.String())
|
|
|
-
|
|
|
+ ret := cp.FindString(address.IP.String())
|
|
|
+ if len(ret) > 1 {
|
|
|
+ addr = address.IP.String()
|
|
|
+ break
|
|
|
+ }
|
|
|
}
|
|
|
- interfaceNames = append(interfaceNames, face.Name)
|
|
|
+ if len(addr) < 1 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ //println(face.Description + "-----------------" + addr)
|
|
|
+ interfaceNames[face.Name] = addr
|
|
|
}
|
|
|
return interfaceNames, nil
|
|
|
}
|
|
|
@@ -72,7 +79,7 @@ func isLoopback(iface pcap.Interface) bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-func Sniffer(interfaceName, sqlName string, port int) (err error) {
|
|
|
+func Sniffer(interfaceName, sqlName, ip string, port int) (err error) {
|
|
|
// 打开环回接口
|
|
|
handle, err := pcap.OpenLive(interfaceName, 1600, true, pcap.BlockForever)
|
|
|
if err != nil {
|
|
|
@@ -89,21 +96,22 @@ func Sniffer(interfaceName, sqlName string, port int) (err error) {
|
|
|
return err
|
|
|
}
|
|
|
glog.XWarning(fmt.Sprintf("Listening on %s\n", interfaceName))
|
|
|
+ ip = strings.ReplaceAll(ip, ".", "_")
|
|
|
+ filename := fmt.Sprintf("%v%v.txt", ip, sqlName)
|
|
|
|
|
|
- filename := fmt.Sprintf("%v%v.txt", sqlName, time.Now().Format("20060102030405"))
|
|
|
- fileHandle, err := os.Create(filename)
|
|
|
+ file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
|
if err != nil {
|
|
|
- glog.XWarning(fmt.Sprintf("os.Create %v error : %v\n", filename, err))
|
|
|
- return err
|
|
|
+ glog.XWarning(fmt.Sprintf("os.OpenFile %v error : %v\n", filename, err))
|
|
|
+ return
|
|
|
}
|
|
|
- defer fileHandle.Close()
|
|
|
+ defer file.Close()
|
|
|
|
|
|
// 创建数据包源
|
|
|
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
|
|
for packet := range packetSource.Packets() {
|
|
|
sqlStatement, err := processPacket(packet)
|
|
|
- if err == nil {
|
|
|
- fileHandle.Write([]byte(sqlStatement + "\n"))
|
|
|
+ if err == nil && len(sqlStatement) > 2 {
|
|
|
+ file.WriteString(time.Now().Format("2006-01-02 03:04:05") + " " + sqlStatement + "\n")
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
@@ -144,6 +152,7 @@ func extractSQLFromPayload(payload []byte) string {
|
|
|
data = strings.ReplaceAll(data, "\n", "")
|
|
|
data = strings.ReplaceAll(data, "\r\n", "")
|
|
|
data = strings.ReplaceAll(data, "\t", "")
|
|
|
+
|
|
|
// 检查是否包含 SQL 关键字
|
|
|
if containsSQLKeyword(data) {
|
|
|
return data
|