|
|
@@ -0,0 +1,162 @@
|
|
|
+package network
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "gbase/glog"
|
|
|
+ "github.com/google/gopacket"
|
|
|
+ "github.com/google/gopacket/layers"
|
|
|
+ "github.com/google/gopacket/pcap"
|
|
|
+ "os"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+func GetLoopbackInterface() (loopbackInterfaceName string, err error) {
|
|
|
+
|
|
|
+ // 获取所有网络接口
|
|
|
+ interfaces, err := pcap.FindAllDevs()
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("pcap.FindAllDevs error : %v\n", err))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ found := false
|
|
|
+
|
|
|
+ // 遍历所有接口,找到环回接口
|
|
|
+ for _, iface := range interfaces {
|
|
|
+
|
|
|
+ if isLoopback(iface) {
|
|
|
+ //fmt.Printf("Name: %s, Description: %s\n", iface.Name, iface.Description)
|
|
|
+ loopbackInterfaceName = iface.Name
|
|
|
+ found = true
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if !found {
|
|
|
+ glog.XWarning("No loopback interface found")
|
|
|
+ return loopbackInterfaceName, fmt.Errorf("No loopback interface found")
|
|
|
+ }
|
|
|
+
|
|
|
+ return loopbackInterfaceName, nil
|
|
|
+}
|
|
|
+
|
|
|
+func GetAllLoopbackInterface() (interfaceNames []string, err error) {
|
|
|
+ interfaceNames = []string{}
|
|
|
+ // 获取所有网络接口
|
|
|
+ interfaces, err := pcap.FindAllDevs()
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("pcap.FindAllDevs error : %v\n", err))
|
|
|
+ return interfaceNames, err
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, face := range interfaces {
|
|
|
+
|
|
|
+ for _, address := range face.Addresses {
|
|
|
+
|
|
|
+ println(face.Name + "-----------------" + address.IP.String())
|
|
|
+
|
|
|
+ }
|
|
|
+ interfaceNames = append(interfaceNames, face.Name)
|
|
|
+ }
|
|
|
+ return interfaceNames, nil
|
|
|
+}
|
|
|
+
|
|
|
+// 判断接口是否是环回接口
|
|
|
+func isLoopback(iface pcap.Interface) bool {
|
|
|
+ for _, address := range iface.Addresses {
|
|
|
+ if address.IP.IsLoopback() {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func Sniffer(interfaceName, sqlName string, port int) (err error) {
|
|
|
+ // 打开环回接口
|
|
|
+ handle, err := pcap.OpenLive(interfaceName, 1600, true, pcap.BlockForever)
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("pcap.OpenLive %v error : %v\n", interfaceName, err))
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer handle.Close()
|
|
|
+
|
|
|
+ // 设置过滤器,只捕获 TCP 1433 端口(SQL Server 端口)的数据包
|
|
|
+ filter := fmt.Sprintf("tcp and port %v", port)
|
|
|
+ err = handle.SetBPFFilter(filter)
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("handle.SetBPFFilter error : %v\n", err))
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ glog.XWarning(fmt.Sprintf("Listening on %s\n", interfaceName))
|
|
|
+
|
|
|
+ filename := fmt.Sprintf("%v%v.txt", sqlName, time.Now().Format("20060102030405"))
|
|
|
+ fileHandle, err := os.Create(filename)
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("os.Create %v error : %v\n", filename, err))
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer fileHandle.Close()
|
|
|
+
|
|
|
+ // 创建数据包源
|
|
|
+ packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
|
|
+ for packet := range packetSource.Packets() {
|
|
|
+ sqlStatement, err := processPacket(packet)
|
|
|
+ if err == nil {
|
|
|
+ fileHandle.Write([]byte(sqlStatement + "\n"))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+func processPacket(packet gopacket.Packet) (sqlStatement string, err error) {
|
|
|
+ // 检测是否存在任何错误
|
|
|
+ errs := packet.ErrorLayer()
|
|
|
+ if err != nil {
|
|
|
+ glog.XWarning(fmt.Sprintf("decoding packet error : %v\n", errs.Error()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 解析 TCP 层
|
|
|
+ tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
|
|
+ if tcpLayer == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 打印应用层/有效载荷
|
|
|
+ applicationLayer := packet.ApplicationLayer()
|
|
|
+ if applicationLayer == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ payload := applicationLayer.Payload()
|
|
|
+ sqlStatement = extractSQLFromPayload(payload)
|
|
|
+ if sqlStatement != "" {
|
|
|
+ return sqlStatement, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return sqlStatement, nil
|
|
|
+}
|
|
|
+
|
|
|
+func extractSQLFromPayload(payload []byte) string {
|
|
|
+ // 将字节转换为字符串
|
|
|
+ data := string(payload)
|
|
|
+ data = strings.ReplaceAll(data, "\r", "")
|
|
|
+ data = strings.ReplaceAll(data, "\n", "")
|
|
|
+ data = strings.ReplaceAll(data, "\r\n", "")
|
|
|
+ data = strings.ReplaceAll(data, "\t", "")
|
|
|
+ // 检查是否包含 SQL 关键字
|
|
|
+ if containsSQLKeyword(data) {
|
|
|
+ return data
|
|
|
+ }
|
|
|
+ return ""
|
|
|
+}
|
|
|
+
|
|
|
+func containsSQLKeyword(data string) bool {
|
|
|
+ keywords := []string{"INSERT", "UPDATE", "DELETE", "SELECT"}
|
|
|
+ for _, keyword := range keywords {
|
|
|
+ if strings.Contains(strings.ToUpper(data), keyword) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|