diff --git a/pom.xml b/pom.xml index 3049e6f..6d256b2 100644 --- a/pom.xml +++ b/pom.xml @@ -159,6 +159,11 @@ nekohtml 1.9.22 + + org.springframework.boot + spring-boot-starter-websocket + 2.2.1.RELEASE + diff --git a/src/main/java/com/rymcu/vertical/config/WebSocketConfigurer.java b/src/main/java/com/rymcu/vertical/config/WebSocketConfigurer.java new file mode 100644 index 0000000..8c935ad --- /dev/null +++ b/src/main/java/com/rymcu/vertical/config/WebSocketConfigurer.java @@ -0,0 +1,17 @@ +package com.rymcu.vertical.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +/** + * WebSocket 配置类 + * @author ronger + */ +@Configuration +public class WebSocketConfigurer { + @Bean + public ServerEndpointExporter serverEndpointExporter() { + return new ServerEndpointExporter(); + } +} diff --git a/src/main/java/com/rymcu/vertical/web/api/common/WebSocketServer.java b/src/main/java/com/rymcu/vertical/web/api/common/WebSocketServer.java new file mode 100644 index 0000000..dc93ac6 --- /dev/null +++ b/src/main/java/com/rymcu/vertical/web/api/common/WebSocketServer.java @@ -0,0 +1,130 @@ +package com.rymcu.vertical.web.api.common; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +import javax.websocket.*; +import javax.websocket.server.PathParam; +import javax.websocket.server.ServerEndpoint; +import java.io.IOException; +import java.util.concurrent.CopyOnWriteArraySet; + +/** + * @author ronger + */ +@ServerEndpoint("/api/v1/websocket/{sid}") +@Component +public class WebSocketServer { + static final Logger log= LoggerFactory.getLogger(WebSocketServer.class); + /** + * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。 + * */ + private static int onlineCount = 0; + /** + * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。 + */ + private static CopyOnWriteArraySet webSocketSet = new CopyOnWriteArraySet(); + + + /** + * 与某个客户端的连接会话,需要通过它来给客户端发送数据 + */ + private Session session; + + /** + * 接收sid + * */ + private String sid = ""; + /** + * 连接建立成功调用的方法*/ + @OnOpen + public void onOpen(Session session,@PathParam("sid") String sid) { + this.session = session; + webSocketSet.add(this); + addOnlineCount(); + log.info("有新窗口开始监听:" + sid + ",当前在线人数为" + getOnlineCount()); + this.sid=sid; + try { + sendMessage("连接成功"); + } catch (IOException e) { + log.error("websocket IO异常"); + } + } + + /** + * 连接关闭调用的方法 + */ + @OnClose + public void onClose() { + webSocketSet.remove(this); + subOnlineCount(); + log.info("有一连接关闭!当前在线人数为" + getOnlineCount()); + } + + /** + * 收到客户端消息后调用的方法 + * + * @param message 客户端发送过来的消息*/ + @OnMessage + public void onMessage(String message, Session session) { + log.info("收到来自窗口" + sid + "的信息:" + message); + //群发消息 + for (WebSocketServer item : webSocketSet) { + try { + item.sendMessage(message); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + /** + * + * @param session + * @param error + */ + @OnError + public void onError(Session session, Throwable error) { + log.error("发生错误"); + error.printStackTrace(); + } + /** + * 实现服务器主动推送 + */ + public void sendMessage(String message) throws IOException { + this.session.getBasicRemote().sendText(message); + } + + + /** + * 群发自定义消息 + * */ + public static void sendInfo(String message,@PathParam("sid") String sid) throws IOException { + log.info("推送消息到窗口"+sid+",推送内容:"+message); + for (WebSocketServer item : webSocketSet) { + try { + //这里可以设定只推送给这个sid的,为null则全部推送 + if(sid==null) { + item.sendMessage(message); + }else if(item.sid.equals(sid)){ + item.sendMessage(message); + } + } catch (IOException e) { + continue; + } + } + } + + public static synchronized int getOnlineCount() { + return onlineCount; + } + + public static synchronized void addOnlineCount() { + WebSocketServer.onlineCount++; + } + + public static synchronized void subOnlineCount() { + WebSocketServer.onlineCount--; + } +}