package com.zehong.framework.web.service;

import com.alibaba.fastjson.JSONArray;
import com.zehong.common.utils.spring.SpringUtils;
import com.zehong.framework.web.domain.server.WebSocketBean;
import com.zehong.system.domain.SysNotice;
import com.zehong.system.service.ISysNoticeService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

/**
 * WebSocket endpoint that keeps a per-role registry of connected clients and
 * pushes text messages (notices) to them.
 *
 * <p>Clients connect to {@code /webSocket/{roles}/{userId}}; {@code roles} is a
 * comma-separated list of role keys and {@code userId} must be parseable as a
 * {@code long}. The registry is a static field, so it is shared by every instance
 * of this class.
 *
 * @author geng
 */
@Component
@ServerEndpoint("/webSocket/{roles}/{userId}")
public class WebSocketServer {

    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);

    /**
     * Maximum number of retries after a failed send (so at most
     * {@code MAX_ERROR_NUM + 1} attempts are made per message).
     */
    private static final int MAX_ERROR_NUM = 10;

    /**
     * Role key -> connections registered under that role.
     *
     * <p>The map is a {@link ConcurrentHashMap} and each value is a
     * {@link CopyOnWriteArrayList}, so lists can be iterated by
     * {@link #findMessageUser} while other threads open or close connections.
     */
    private static final Map<String, List<WebSocketBean>> webSocketInfo = new ConcurrentHashMap<>();

    /**
     * Registers a newly opened connection under every role in {@code roles} and
     * then sends the user's current notice list once.
     *
     * @param session the session that was just opened
     * @param roles   comma-separated role keys taken from the URL path
     * @param userId  user id taken from the URL path
     * @throws NumberFormatException if {@code userId} is not a valid {@code long};
     *                               nothing is registered in that case
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("roles") String roles, @PathParam("userId") String userId) {
        // Parse once, before touching the registry, so a malformed id cannot
        // leave a half-registered connection behind.
        Long parsedUserId = Long.valueOf(userId);

        // One bean per connection, shared by all of its role lists.
        WebSocketBean bean = new WebSocketBean();
        bean.setSession(session);
        bean.setUserId(parsedUserId);

        for (String role : roles.split(",")) {
            // compute() runs atomically per key, so creating the list and adding to it
            // cannot interleave with the removal done in onClose.
            webSocketInfo.compute(role, (key, existing) -> {
                List<WebSocketBean> beans = existing;
                if (beans == null) {
                    beans = new CopyOnWriteArrayList<>();
                }
                beans.add(bean);
                return beans;
            });
        }

        // Sent once per connection, regardless of how many roles were listed.
        sendMessage(bean, initNotice(parsedUserId));

        log.info("客户端连接服务器session id :{}，当前连接数：{}", session.getId(), connectionCount());
    }

    /**
     * Loads the notices for a user and serialises them as a JSON array.
     *
     * @param userId id of the user whose notices are queried
     * @return the JSON array string, or an empty string when there are no notices
     */
    private String initNotice(Long userId) {
        SysNotice query = new SysNotice();
        query.setUserId(userId);
        // Looked up through SpringUtils on every call rather than injected.
        // NOTE(review): presumably because the container, not Spring, instantiates
        // endpoint objects - confirm.
        ISysNoticeService sysNoticeService = SpringUtils.getBean(ISysNoticeService.class);
        List<SysNotice> notices = sysNoticeService.selectNoticeList(query);
        if (CollectionUtils.isEmpty(notices)) {
            return "";
        }
        return JSONArray.toJSONString(notices);
    }

    /**
     * Removes every registry entry that belongs to the closed session and drops
     * role keys that no longer have any connection.
     *
     * @param session the session that was closed
     */
    @OnClose
    public void onClose(Session session) {
        String closedId = session.getId();
        for (String role : webSocketInfo.keySet()) {
            // Returning null from computeIfPresent removes the mapping; doing it inside
            // the atomic remapping avoids discarding a list that onOpen is adding to.
            webSocketInfo.computeIfPresent(role, (key, beans) -> {
                beans.removeIf(item -> item.getSession().getId().equals(closedId));
                return beans.isEmpty() ? null : beans;
            });
        }
        log.info("客户端断开连接，当前连接数：{}", connectionCount());
    }

    /**
     * Handles a text message sent by a client. Currently the message is only logged.
     *
     * @param session the sending session
     * @param message the received text
     */
    @OnMessage
    public void onMessage(Session session, String message) {
        log.info("客户端 session id: {}，消息:{}", session.getId(), message);
    }

    /**
     * Logs an error reported for a session.
     *
     * @param session   the session the error relates to
     * @param throwable the reported error
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        log.error("发生错误{}", throwable.getMessage(), throwable);
    }

    /**
     * Sends a message to the connections registered under a role.
     *
     * @param role    role key whose connections are targeted; nothing is sent when it
     *                is {@code null} or has no registered connection
     * @param userId  when {@code null} the message goes to every connection of the role,
     *                otherwise only to connections registered with this user id
     * @param message message body
     */
    public void findMessageUser(String role, Long userId, String message) {
        if (role == null) {
            // ConcurrentHashMap rejects null keys, so there can be no such role.
            return;
        }
        List<WebSocketBean> beans = webSocketInfo.get(role);
        if (CollectionUtils.isEmpty(beans)) {
            return;
        }
        for (WebSocketBean item : beans) {
            if (userId == null || userId.equals(item.getUserId())) {
                sendMessage(item, message);
            }
        }
    }

    /**
     * Counts the distinct sessions currently held in the registry. A connection
     * registered under several roles is counted once.
     *
     * @return number of distinct session ids in the registry
     */
    private static long connectionCount() {
        return webSocketInfo.values().stream()
                .flatMap(List::stream)
                .map(item -> item.getSession().getId())
                .distinct()
                .count();
    }

    /**
     * Sends a text message over the bean's session, retrying on failure.
     *
     * <p>The retry loop is bounded by a local counter, and it stops immediately when
     * the session is missing or no longer open. The bean's error counter is reset
     * whenever this method returns.
     *
     * @param bean    connection wrapper to send through
     * @param message message body; a {@code null} message is ignored
     */
    private void sendMessage(WebSocketBean bean, String message) {
        if (message == null) {
            return;
        }
        Session session = bean.getSession();
        for (int attempt = 0; attempt <= MAX_ERROR_NUM; attempt++) {
            if (session == null || !session.isOpen()) {
                // Retrying on a closed session can never succeed.
                log.warn("会话已关闭，放弃发送消息");
                bean.cleanErrorNum();
                return;
            }
            try {
                session.getBasicRemote().sendText(message);
                bean.cleanErrorNum();
                return;
            } catch (Exception e) {
                log.error("发送消息失败{}", e.getMessage(), e);
            }
        }
        log.error("发送消息失败超过最大次数");
        bean.cleanErrorNum();
    }
}
