diff --git a/cmd/make.go b/cmd/make.go index 61a1c39..0ee0948 100644 --- a/cmd/make.go +++ b/cmd/make.go @@ -26,7 +26,8 @@ var ( }, Examples: "make -m act -i 1: 创建活动1的接口与服务文件 \n" + "make -m logic -n test: 创建test的服务文件 \n" + - "make -m config -n test: 创建配置文件", + "make -m config -n test: 创建配置文件 \n" + + "make -m socket -n test: 创建socket文件 \n", Func: func(ctx context.Context, parser *gcmd.Parser) (err error) { //g.Dump(parser.GetOptAll(), parser.GetArgAll()) @@ -53,6 +54,12 @@ var ( return } err = this.Config(name) + case "socket": + var name = parser.GetOpt("name").String() + if name == "" { + return + } + err = this.Socket(name) } return @@ -116,3 +123,25 @@ func (c *cMake) Config(name string) (err error) { return } + +func (c *cMake) Socket(name string) (err error) { + var filePath = fmt.Sprintf("internal/socket/%s/%s_new.go", name, gstr.CaseSnake(name)) + //生成文件不覆盖 + if !gfile.Exists(filePath) { + // 生成目录文件 + get, _ := fs.ReadFile(ConfigFiles, "make/socket") + fileStr := string(get) + fileStr = gstr.Replace(fileStr, "{name}", name) + err = gfile.PutContents(filePath, fileStr) + + //生成方法文件 + var filePath2 = fmt.Sprintf("internal/socket/%s/%s.go", name, gstr.CaseSnake(name)) + get, _ = fs.ReadFile(ConfigFiles, "make/socket2") + fileStr = string(get) + fileStr = gstr.Replace(fileStr, "{name}", name) + fileStr = gstr.Replace(fileStr, "{func}", gstr.CaseCamel(name)) + err = gfile.PutContents(filePath2, fileStr) + } + + return +} diff --git a/cmd/make/socket b/cmd/make/socket new file mode 100644 index 0000000..9919a22 --- /dev/null +++ b/cmd/make/socket @@ -0,0 +1,10 @@ +package {name} + +type {name} struct { +} + +func New() *{name} { + return &{name}{} +} + +func init() {} diff --git a/cmd/make/socket2 b/cmd/make/socket2 new file mode 100644 index 0000000..6f170be --- /dev/null +++ b/cmd/make/socket2 @@ -0,0 +1,19 @@ +package {name} + +import ( + "github.com/ayflying/utility_go/pkg" + "github.com/ayflying/utility_go/pkg/websocket" + "google.golang.org/protobuf/proto" +) + +func (s *{name}) {func}Handler(conn *websocket.WebsocketData, req any) (err error) { + var data = &v1.{func}2S{} + err = proto.Unmarshal(req.([]byte), data) + + var res = &v1.{func}2C{} + + resp, err := proto.Marshal(res) + pkg.Websocket().Send(000000, conn.Uid, resp) + + return +} diff --git a/pkg/websocket/registerer.go b/pkg/websocket/registerer.go index 741eaad..a59b558 100644 --- a/pkg/websocket/registerer.go +++ b/pkg/websocket/registerer.go @@ -1,14 +1,24 @@ package websocket +import "google.golang.org/protobuf/proto" + // 定义一个处理方法的类型 type Handler func(conn *WebsocketData, req any) (err error) -type OnConnectHandler func(conn *WebsocketData) +type Handler2 func(conn *WebsocketData) + +type HandlerMessage func(conn *WebsocketData, req any) + +type PbType func(code int32, data []byte) proto.Message +type PbType2 func(data []byte) (int, []byte) // 路由器的处理映射 var ( handlers = make(map[int]Handler) - OnConnectHandlers = make([]OnConnectHandler, 0) - OnCloseHandlers = make([]OnConnectHandler, 0) + OnConnectHandlers = make([]Handler2, 0) + OnCloseHandlers = make([]Handler2, 0) + onMessageHandlers = make([]HandlerMessage, 0) + Byte2Pb = make([]PbType, 0) + Pb2Bytes = make([]PbType2, 0) ) // 注册方法,将某个消息路由器ID和对应的处理方法关联起来 @@ -17,10 +27,23 @@ func (s *SocketV1) RegisterRouter(cmd int, handler Handler) { } //注册方法,讲长连接登陆方法进行注册 -func (s *SocketV1) RegisterOnConnect(_func OnConnectHandler) { +func (s *SocketV1) RegisterOnConnect(_func Handler2) { OnConnectHandlers = append(OnConnectHandlers, _func) } -func (s *SocketV1) RegisterOnClose(_func OnConnectHandler) { +func (s *SocketV1) RegisterOnClose(_func Handler2) { OnCloseHandlers = append(OnCloseHandlers, _func) } + +//注册方法长连接消息体 +func (s *SocketV1) RegisterMessage(_func HandlerMessage) { + onMessageHandlers = append(onMessageHandlers, _func) +} + +func (s *SocketV1) RegisterByte2Pb(_func PbType) { + Byte2Pb = append(Byte2Pb, _func) +} + +func (s *SocketV1) RegisterPb2Byte(_func PbType2) { + Pb2Bytes = append(Pb2Bytes, _func) +} diff --git a/pkg/websocket/socket_new.go b/pkg/websocket/socket_new.go index 9503b5f..2186a23 100644 --- a/pkg/websocket/socket_new.go +++ b/pkg/websocket/socket_new.go @@ -2,13 +2,15 @@ package websocket import ( "context" + "fmt" "github.com/gogf/gf/v2/container/gmap" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/ghttp" "github.com/gogf/gf/v2/os/glog" - "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2/util/guid" "github.com/google/uuid" "github.com/gorilla/websocket" + "google.golang.org/protobuf/proto" "strconv" "sync" ) @@ -22,19 +24,19 @@ var ( //Conn map[uuid.UUID]*WebsocketData lock sync.Mutex - m = gmap.New(true) + m = gmap.NewHashMap(true) ) type WebsocketData struct { Ws *websocket.Conn - Uuid uuid.UUID + Uuid string Uid int64 Ctx context.Context } func NewV1() *SocketV1 { return &SocketV1{ - Type: 1, + Type: 2, } } @@ -77,11 +79,9 @@ func (s *SocketV1) Load(serv *ghttp.Server, prefix string) { // @receiver s // @param conn func (s *SocketV1) OnConnect(ctx context.Context, conn *websocket.Conn) { - //lock.Lock() - //defer lock.Unlock() defer conn.Close() - id, _ := uuid.NewUUID() + id := guid.S() ip := conn.RemoteAddr().String() data := &WebsocketData{ Uuid: id, @@ -103,8 +103,10 @@ func (s *SocketV1) OnConnect(ctx context.Context, conn *websocket.Conn) { for { //进入当前连接线程拥堵 msgType, msg, err := conn.ReadMessage() + s.Type = msgType if err != nil { //客户端断开返回错误,断开当前连接 + //g.Log().Error(ctx, err) break } s.OnMessage(m.Get(id).(*WebsocketData), msg, msgType) @@ -121,53 +123,142 @@ func (s *SocketV1) OnConnect(ctx context.Context, conn *websocket.Conn) { // @param msg // @param msgType func (s *SocketV1) OnMessage(conn *WebsocketData, req []byte, msgType int) { - s.Type = msgType - //g.Log().Debugf(ctx, "收到消息:%v,type=%v,conn=%v", string(req), msgType, conn) - //s.Send(conn.Uuid, msg) - //s.SendAll(msg) + s.Type = 2 + + var cmd int + var msg []byte + //uid := conn.Uid + for _, v := range Pb2Bytes { + cmd, msg = v(req) + } + //msgStr := string(req) - msg := req[8:] - cmd := gconv.Int(req[:8]) - //GetRouter(cmd, conn.Uid, msg) + //cmd = gconv.Int(msgStr[:8]) + //msg = []byte(msgStr[8:]) + handler, exist := handlers[cmd] if exist { //匹配上路由器 err := handler(conn, msg) - g.Log().Error(conn.Ctx, err) + if err != nil { + g.Log().Error(conn.Ctx, err) + } } else { //fmt.Println("未注册的路由器ID:", cmd) - s.Send(conn.Uuid, []byte("未注册的协议号:"+strconv.Itoa(cmd))) + s.Send(20000000, conn.Uid, []byte("未注册的协议号:"+strconv.Itoa(cmd))) s.OnClose(conn) return } } +//绑定用户编号 +func (s *SocketV1) BindUid(conn *WebsocketData, uid int64) { + lock.Lock() + defer lock.Unlock() + + cacheKey := fmt.Sprintf("socket:uid:%d", uid) + g.Redis().Set(nil, cacheKey, conn.Uuid) + + if conn.Uid == 0 { + conn.Uid = uid + } + +} + +//解绑用户 +func (s *SocketV1) UnBindUid(uid int64) { + lock.Lock() + defer lock.Unlock() + + cacheKey := fmt.Sprintf("socket:uid:%d", uid) + g.Redis().Del(nil, cacheKey) +} + +// Uid2Uuid 用户编号转uuid唯一标识 +func (s *SocketV1) Uid2Uuid(uid int64) (uuid string) { + cacheKey := fmt.Sprintf("socket:uid:%d", uid) + get, _ := g.Redis().Get(nil, cacheKey) + if get.IsNil() { + return + } + + uuid = get.String() + + //如果不在线了 + if !m.Contains(uuid) { + // 解绑用户编号 + s.UnBindUid(uid) + return + } + + return +} + +// SendUuid +// +// @Description: +// @receiver s +// @param uid +// @param data +func (s *SocketV1) SendUuid(code int32, id uuid.UUID, data []byte) { + if !m.Contains(id) { + return + } + + conn := m.Get(id).(*WebsocketData) + + //前置方法 + + for _, v := range Byte2Pb { + temp := v(code, data) + data, _ = proto.Marshal(temp) + } + + conn.Ws.WriteMessage(s.Type, data) + + return +} + // Send // // @Description: // @receiver s // @param uid // @param data -// @return err -func (s *SocketV1) Send(id uuid.UUID, data []byte) (err error) { - if !m.Contains(id) { +func (s *SocketV1) Send(code int32, uid int64, data []byte) { + uuid := s.Uid2Uuid(uid) + if uuid == "" { return } - conn := m.Get(id).(*WebsocketData) + if !m.Contains(uuid) { + return + } + + conn := m.Get(uuid).(*WebsocketData) + + //前置方法 + + for _, v := range Byte2Pb { + temp := v(code, data) + data, _ = proto.Marshal(temp) + } conn.Ws.WriteMessage(s.Type, data) return } // 批量发送 -func (s *SocketV1) SendAll(data []byte) { +func (s *SocketV1) SendAll(code int32, data []byte) { + for _, v := range Byte2Pb { + temp := v(code, data) + data, _ = proto.Marshal(temp) + } + m.Iterator(func(k interface{}, v interface{}) bool { - //fmt.Printf("%v:%v ", k, v) conn := v.(*WebsocketData) conn.Ws.WriteMessage(s.Type, data) - return true }) } @@ -185,11 +276,14 @@ func (s *SocketV1) OnClose(conn *WebsocketData) { for _, connect := range OnCloseHandlers { connect(conn) } - + uid := conn.Uid // 可能的后续操作: // 1. 更新连接状态或从连接池移除 // 2. 发送通知或清理关联资源 // 3. 执行特定于业务的断开处理 m.Remove(conn.Uuid) + if uid > 0 { + s.UnBindUid(uid) + } conn.Ws.Close() }