From 98c6a72f96e843bd1cb4dfca9ee56b1a996f2c67 Mon Sep 17 00:00:00 2001
From: jjxu <428192774@qq.com>
Date: Sun, 15 Jun 2025 19:56:18 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E5=AE=8C=E6=88=90=E7=94=A8=E6=88=B7?=
 =?UTF-8?q?=E8=BF=9B=E5=85=A5=E8=81=8A=E5=A4=A9=E5=AE=A4=E7=9B=B4=E6=8E=A5?=
 =?UTF-8?q?=E5=9B=9E=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/common/ws/chatRoom.go                 | 24 ++++++------
 pkg/common/ws/consts.go                   |  3 +-
 pkg/service/asChat/logic/chat.go          |  2 +-
 pkg/service/asChat/robot/replyAndRuler.go | 48 ++++++++++++++---------
 pkg/service/asChat/robot/robot.go         | 10 ++---
 pkg/service/asChat/robot/task.go          |  2 +-
 6 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/pkg/common/ws/chatRoom.go b/pkg/common/ws/chatRoom.go
index 1170640..d69b90b 100644
--- a/pkg/common/ws/chatRoom.go
+++ b/pkg/common/ws/chatRoom.go
@@ -9,6 +9,7 @@ package ws
 import (
 	"encoding/json"
 	"fmt"
+	"fonchain-fiee/api/accountFiee"
 	"fonchain-fiee/pkg/utils"
 	"github.com/gorilla/websocket"
 	"go.uber.org/zap"
@@ -95,7 +96,7 @@ func (o *ChatRoom) Run() {
 		select {
 		// 注册事件
 		case newClient := <-o.register:
-			o.pushEvent(EventUserJoin, EventProgressBefore, newClient)
+			o.pushEvent(EventUserJoin, EventProgressBefore, nil, newClient)
 
 			o.clientsRwLocker.Lock()
 			//添加到客户端集合中
@@ -138,11 +139,11 @@ func (o *ChatRoom) Run() {
 				//再把自己的客户端加入会话
 				o.Session[newClient.SessionId] = append(o.Session[newClient.SessionId], newClient)
 			}
-			o.pushEvent(EventUserJoin, EventProgressAfter, newClient)
+			o.pushEvent(EventUserJoin, EventProgressAfter, nil, newClient)
 			o.clientsRwLocker.Unlock()
 		//注销事件
 		case client := <-o.UnRegister:
-			o.pushEvent(EventUserLeave, EventProgressBefore, client)
+			o.pushEvent(EventUserLeave, EventProgressBefore, nil, client)
 			//panic 恢复
 			defer func() {
 				if r := recover(); r != "" {
@@ -167,7 +168,7 @@ func (o *ChatRoom) Run() {
 				delete(o.clients[client.UserId], client.ClientId)
 				fmt.Printf("ws客户端%s 被注销\n", client.ClientId)
 			}
-			o.pushEvent(EventUserLeave, EventProgressAfter, client)
+			o.pushEvent(EventUserLeave, EventProgressAfter, nil, client)
 		// 消息群发事件
 		case messageInfo := <-o.broadcast:
 			o.Broadcast(messageInfo.message, messageInfo.UserIds...)
@@ -189,13 +190,9 @@ func (o *ChatRoom) Register(c *Client) (sessionId string) {
 // sessionId: 会话id
 // msgType: 消息类型
 // message: 消息内容
-func (o *ChatRoom) SendSessionMessage(sendUserId int64, sessionId string, msgType WsType, message any) (userIdInSession []int64, err error) {
+func (o *ChatRoom) SendSessionMessage(chatUser *accountFiee.ChatUserData, sessionId string, msgType WsType, message any) (userIdInSession []int64, err error) {
 	o.clientsRwLocker.Lock()
 	defer o.clientsRwLocker.Unlock()
-
-	o.pushEvent(EventChatMessage, EventProgressBefore, sendUserId, sessionId, msgType, message)
-	defer o.pushEvent(EventChatMessage, EventProgressAfter, sendUserId, sessionId, msgType, message)
-
 	var msg = WsSessionInfo{
 		Type:    msgType,
 		Content: message,
@@ -213,11 +210,13 @@ func (o *ChatRoom) SendSessionMessage(sendUserId int64, sessionId string, msgTyp
 			_, exist := o.clients[client.UserId][client.ClientId]
 			if exist {
 				usableClients = append(usableClients, o.Session[sessionId][i])
+				o.pushEvent(EventChatMessage, EventProgressBefore, chatUser, o.Session[sessionId][i], message)
 			}
 		}
 		fmt.Printf("client:%+v\n", client)
-		if client != nil && client.UserId != sendUserId {
+		if client != nil && client.UserId != chatUser.ID {
 			client.Send <- msgBytes
+			o.pushEvent(EventChatMessage, EventProgressAfter, chatUser, o.Session[sessionId][i], message)
 			userIdInSession = append(userIdInSession, client.UserId)
 		}
 		//client.Send <- msgBytes
@@ -343,7 +342,7 @@ func (o *ChatRoom) UnRegisterEventListener(listenerChan *EventListener) {
 }
 
 // pushEvent 推送聊天室事件
-func (o *ChatRoom) pushEvent(eventType EventType, progress EventProgress, data ...any) {
+func (o *ChatRoom) pushEvent(eventType EventType, progress EventProgress, chatUser *accountFiee.ChatUserData, client *Client, data ...any) {
 	o.EventRwLocker.Lock()
 	defer o.EventRwLocker.Unlock()
 	for _, listener := range o.eventBus {
@@ -362,7 +361,8 @@ func (o *ChatRoom) pushEvent(eventType EventType, progress EventProgress, data .
 				EventType:    eventType,
 				ProgressType: progress,
 			},
-			Data: data,
+			Client: client,
+			Data:   data,
 		}
 	}
 }
diff --git a/pkg/common/ws/consts.go b/pkg/common/ws/consts.go
index 263d128..5b98685 100644
--- a/pkg/common/ws/consts.go
+++ b/pkg/common/ws/consts.go
@@ -42,7 +42,8 @@ type ListenEvent struct {
 }
 type ListenEventData struct {
 	ListenEvent
-	Data interface{}
+	Client *Client
+	Data   any
 }
 type ListenEventChan chan ListenEventData
 type EventListener struct {
diff --git a/pkg/service/asChat/logic/chat.go b/pkg/service/asChat/logic/chat.go
index 2b7009b..d88cdb8 100644
--- a/pkg/service/asChat/logic/chat.go
+++ b/pkg/service/asChat/logic/chat.go
@@ -77,7 +77,7 @@ func NewMessage(ctx context.Context, cache *chatCache.ChatCache, chatUser *accou
 	var notice = dto.MessageListType{}
 	notice.BuildMessage(resp.Data)
 	fmt.Printf("ws消息提醒:%+v\n", notice)
-	_, err = consts.ChatRoom.SendSessionMessage(chatUser.ID, request.SessionId, ws.NewChatMsgType, notice)
+	_, err = consts.ChatRoom.SendSessionMessage(chatUser, request.SessionId, ws.NewChatMsgType, notice)
 	if err != nil {
 		log.Print("发送新消息通知失败", zap.Error(err), zap.Any("notice", notice))
 	}
diff --git a/pkg/service/asChat/robot/replyAndRuler.go b/pkg/service/asChat/robot/replyAndRuler.go
index b245b22..bccb4a5 100644
--- a/pkg/service/asChat/robot/replyAndRuler.go
+++ b/pkg/service/asChat/robot/replyAndRuler.go
@@ -22,9 +22,9 @@ type Reply struct {
 	Rules    []IRule
 }
 
-func (r *Reply) Hit(eventType ws.EventType, msg *accountFiee.ChatRecordData, robotId int64) (hit bool, runTime time.Time, logic func() error) {
+func (r *Reply) Hit(eventType ws.EventType, chatUser *accountFiee.ChatUserData, wsClient *ws.Client, msg *accountFiee.ChatRecordData, robotInfo *accountFiee.ChatUserData) (hit bool, runTime time.Time, logic func(msg string) error) {
 	for _, rule := range r.Rules {
-		hit, runTime, logic = rule.Hit(eventType, msg, robotId)
+		hit, runTime, logic = rule.Hit(eventType, chatUser, wsClient, msg, robotInfo)
 		if hit {
 			return
 		}
@@ -33,7 +33,7 @@ func (r *Reply) Hit(eventType ws.EventType, msg *accountFiee.ChatRecordData, rob
 }
 
 type IRule interface {
-	Hit(eventType ws.EventType, msg *accountFiee.ChatRecordData, robotId int64) (hit bool, runTime time.Time, logic func() error)
+	Hit(eventType ws.EventType, chatUser *accountFiee.ChatUserData, wsClient *ws.Client, msg *accountFiee.ChatRecordData, robotInfo *accountFiee.ChatUserData) (hit bool, runTime time.Time, logic func(msg string) error)
 }
 
 func NewReplyWhenHitKeywords(keywords []string) IRule {
@@ -45,17 +45,20 @@ type ReplyWhenHitKeywords struct {
 	Keywords []string `json:"keywords"`
 }
 
-func (k ReplyWhenHitKeywords) Hit(eventType ws.EventType, record *accountFiee.ChatRecordData, robotId int64) (hit bool, runTime time.Time, logic func() error) {
+func (k ReplyWhenHitKeywords) Hit(eventType ws.EventType, chatUser *accountFiee.ChatUserData, wsClient *ws.Client, record *accountFiee.ChatRecordData, robotInfo *accountFiee.ChatUserData) (hit bool, runTime time.Time, logic func(msg string) error) {
+	if record == nil {
+		return
+	}
 	for _, v := range k.Keywords {
 		if strings.Contains(record.Content, v) {
 			hit = true
 			break
 		}
 	}
-	logic = func() error {
+	logic = func(msg string) error {
 		var notice = dto.MessageListType{}
 		notice.BuildMessage(record)
-		_, err := consts.ChatRoom.SendSessionMessage(robotId, record.SessionId, ws.NewChatMsgType, notice)
+		_, err := consts.ChatRoom.SendSessionMessage(robotInfo, record.SessionId, ws.NewChatMsgType, notice)
 		return err
 	}
 	return
@@ -69,16 +72,17 @@ func NewReplyWhenUserJoinSession() IRule {
 type ReplyWhenUserJoinSession struct {
 }
 
-func (k ReplyWhenUserJoinSession) Hit(eventType ws.EventType, record *accountFiee.ChatRecordData, robotId int64) (hit bool, runTime time.Time, logic func() error) {
+func (k ReplyWhenUserJoinSession) Hit(eventType ws.EventType, chatUser *accountFiee.ChatUserData, wsClient *ws.Client, record *accountFiee.ChatRecordData, robotInfo *accountFiee.ChatUserData) (hit bool, runTime time.Time, logic func(msg string) error) {
 	if eventType != ws.EventUserJoin {
 		return
 	}
-	if record == nil {
+	if wsClient == nil {
 		return
 	}
-	queryRes, err := service.AccountFieeProvider.GetChatRecordList(context.Background(), &accountFiee.GetChatRecordListRequest{
+	ctx := context.Background()
+	queryRes, err := service.AccountFieeProvider.GetChatRecordList(ctx, &accountFiee.GetChatRecordListRequest{
 		Query: &accountFiee.ChatRecordData{
-			SessionId: record.SessionId,
+			SessionId: wsClient.SessionId,
 		},
 		Page:     1,
 		PageSize: 1,
@@ -90,18 +94,26 @@ func (k ReplyWhenUserJoinSession) Hit(eventType ws.EventType, record *accountFie
 	//如果最近一次的消息也是机器人发送的,就不再发送了
 	for i, v := range queryRes.List {
 		if i == 0 {
-			if v.UserId == robotId {
+			if v.UserId == robotInfo.ID {
 				return
 			} else {
 				break
 			}
 		}
 	}
-
-	logic = func() error {
+	hit = true
+	logic = func(msg string) error {
 		var notice = dto.MessageListType{}
-		notice.BuildMessage(record)
-		_, err = consts.ChatRoom.SendSessionMessage(robotId, record.SessionId, ws.NewChatMsgType, notice)
+		newRecord := &accountFiee.ChatRecordData{
+			SessionId: wsClient.SessionId,
+			UserId:    wsClient.UserId,
+			Name:      wsClient.SessionId,
+			Avatar:    robotInfo.Avatar,
+			MsgType:   1,
+			Content:   msg,
+		}
+		notice.BuildMessage(newRecord)
+		_, err = consts.ChatRoom.SendSessionMessage(robotInfo, wsClient.SessionId, ws.NewChatMsgType, notice)
 		return err
 	}
 	return
@@ -118,12 +130,12 @@ type ReplyWhenWaiterNoAction struct {
 	DelaySecond time.Duration
 }
 
-func (k *ReplyWhenWaiterNoAction) Hit(eventType ws.EventType, record *accountFiee.ChatRecordData, robotId int64) (hit bool, runTime time.Time, logic func() error) {
+func (k *ReplyWhenWaiterNoAction) Hit(eventType ws.EventType, chatUser *accountFiee.ChatUserData, wsClient *ws.Client, record *accountFiee.ChatRecordData, robotInfo *accountFiee.ChatUserData) (hit bool, runTime time.Time, logic func(msg string) error) {
 	runTime = time.Now().Add(k.DelaySecond * time.Second)
-	logic = func() error {
+	logic = func(msg string) error {
 		var notice = dto.MessageListType{}
 		notice.BuildMessage(record)
-		_, err := consts.ChatRoom.SendSessionMessage(robotId, record.SessionId, ws.NewChatMsgType, notice)
+		_, err := consts.ChatRoom.SendSessionMessage(robotInfo, record.SessionId, ws.NewChatMsgType, notice)
 		return err
 	}
 	return
diff --git a/pkg/service/asChat/robot/robot.go b/pkg/service/asChat/robot/robot.go
index 63d0e39..9a56c3a 100644
--- a/pkg/service/asChat/robot/robot.go
+++ b/pkg/service/asChat/robot/robot.go
@@ -157,7 +157,7 @@ func (r *Robot) Run() {
 				if now.After(task.RunTime) {
 					// 执行任务
 					go func() {
-						err := task.Run()
+						err := task.Run(task.Response)
 						if err != nil {
 							log.Printf("聊天机器人[%d]回复消息失败:%v", r.Info.ID, err)
 						}
@@ -176,10 +176,10 @@ func (r *Robot) Run() {
 			switch event.EventType {
 			case ws.EventUserJoin: //用户加入聊天室
 				for _, ruleResponse := range r.Rules {
-					hit, runtime, logic := ruleResponse.Hit(ws.EventUserJoin, nil, r.Info.ID)
+					hit, runtime, logic := ruleResponse.Hit(ws.EventUserJoin, nil, event.Client, nil, r.Info)
 					if hit {
 						if runtime.IsZero() {
-							err := logic()
+							err := logic(ruleResponse.Response)
 							if err != nil {
 								log.Printf("robot 执行任务失败:%v\n", err)
 							}
@@ -195,10 +195,10 @@ func (r *Robot) Run() {
 				}
 			case ws.EventChatMessage:
 				for _, ruleResponse := range r.Rules {
-					hit, runtime, logic := ruleResponse.Hit(ws.EventChatMessage, nil, r.Info.ID)
+					hit, runtime, logic := ruleResponse.Hit(ws.EventUserJoin, nil, event.Client, nil, r.Info)
 					if hit {
 						if !runtime.IsZero() {
-							err := logic()
+							err := logic(ruleResponse.Response)
 							if err != nil {
 								log.Printf("robot 执行任务失败:%v\n", err)
 							}
diff --git a/pkg/service/asChat/robot/task.go b/pkg/service/asChat/robot/task.go
index 78c5e78..94644d7 100644
--- a/pkg/service/asChat/robot/task.go
+++ b/pkg/service/asChat/robot/task.go
@@ -10,6 +10,6 @@ import "time"
 
 type RobotTask struct {
 	RunTime  time.Time
-	Run      func() error
+	Run      func(msg string) error
 	Response string
 }