手写rpc框架2


参考视频: 自己动手实现RPC框架,本内容为视频学习笔记。

三篇内容:

  • 理论篇:rpc核心原理、现有框架对比、相关技术

  • 实战篇:代码实现、使用案例

  • 总结篇

上节:手写rpc框架1

实战篇

序列化

序列化模块基于 fastjson 实现。

image-20220125165723978

/**
 * Deserialization contract: turns a raw byte array back into an object.
 *
 * @author hqingLau
 **/
public interface Decoder {
    /**
     * Decodes {@code bytes} into an instance of {@code clazz}.
     *
     * @param bytes the serialized payload
     * @param clazz the type to decode into
     * @param <T>   the target type
     * @return the decoded object
     */
    <T> T decode(byte[] bytes,Class<T> clazz);
}


/**
 * Serialization contract: turns an object into a byte array.
 *
 * @author hqingLau
 **/
public interface Encoder {
    /**
     * Encodes {@code obj} into bytes.
     *
     * @param obj the object to serialize
     * @return the serialized payload
     */
    byte[] encode(Object obj);
}


/**
 * JSON-backed {@link Decoder}; parsing is delegated to fastjson.
 *
 * @author hqingLau
 **/
public class JSONDecoder implements Decoder {

    /**
     * Parses the JSON document held in {@code bytes} into an instance of {@code clazz}.
     *
     * @param bytes JSON payload
     * @param clazz the type to materialize
     * @return the parsed object
     */
    @Override
    public <T> T decode(byte[] bytes, Class<T> clazz) {
        T parsed = JSON.parseObject(bytes, clazz);
        return parsed;
    }
}


/**
 * JSON-backed {@link Encoder}; serialization is delegated to fastjson.
 *
 * @author hqingLau
 **/
public class JSONEncoder implements Encoder {

    /**
     * Serializes {@code obj} to its JSON byte representation.
     *
     * @param obj the object to serialize
     * @return the JSON bytes
     */
    @Override
    public byte[] encode(Object obj) {
        byte[] serialized = JSON.toJSONBytes(obj);
        return serialized;
    }
}

在pom文件中加入依赖。正如第一节所说,在父项目pom的dependencyManagement中指定版本号之后,子模块无需再写版本号。

<dependencies>
    <!-- No <version> element: the version is inherited from the parent POM's dependencyManagement. -->
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
    </dependency>
</dependencies>

网络传输

image-20220125174136775

该部分需要用到proto模块,加入依赖:

 <dependencies>
     <!-- commons-io and jetty-servlet carry no <version>; presumably managed by the parent POM (verify). -->
     <dependency>
         <groupId>commons-io</groupId>
         <artifactId>commons-io</artifactId>
     </dependency>
     <dependency>
         <groupId>org.eclipse.jetty</groupId>
         <artifactId>jetty-servlet</artifactId>
     </dependency>
     <!-- Sibling proto module, pinned to this project's own version. -->
     <dependency>
         <groupId>cn.orzlinux</groupId>
         <artifactId>cn-orzlinux-rpc-proto</artifactId>
         <version>${project.version}</version>
     </dependency>
</dependencies>

从名字就能看出各个class的用途,代码如下:

/**
 * Server side of the transport layer.
 * <ol>
 *   <li>Start and listen on a port</li>
 *   <li>Receive requests</li>
 *   <li>Stop listening</li>
 * </ol>
 *
 * @author hqingLau
 **/
public interface TransportServer {
    /**
     * Prepares the server.
     *
     * @param port    the port to listen on
     * @param handler the callback that processes each incoming request
     */
    void init(int port,RequestHandler handler);

    /** Starts listening. */
    void start();

    /** Stops listening. */
    void stop();
}

/**
 * Client side of the transport layer.
 * <ol>
 *   <li>Create a connection</li>
 *   <li>Send data and wait for the response</li>
 *   <li>Close the connection</li>
 * </ol>
 *
 * @author hqingLau
 **/
public interface TransportClient {
    /**
     * Connects to the given peer.
     *
     * @param peer the remote endpoint
     */
    void connect(Peer peer);

    // Writes the data, then waits for the response.
    /**
     * Sends {@code data} and returns the response.
     *
     * @param data the request payload
     * @return a stream over the response
     */
    InputStream write(InputStream data);

    /** Closes the connection. */
    void close();
}

/**
 * Handler for incoming network requests.
 *
 * @author hqingLau
 **/
public interface RequestHandler {
    /**
     * Processes one request.
     *
     * @param recive stream carrying the received request data
     * @param to     stream the response is written to
     */
    void onRequest(InputStream recive, OutputStream to);
}



@Slf4j
public class HTTPTransportServer implements TransportServer{
    private RequestHandler handler;
    private Server server;

    @Override
    public void init(int port, RequestHandler handler) {
        this.handler = handler;
        this.server = new Server(port);
        //servlet 接受请求
        ServletContextHandler ctx = new ServletContextHandler();
        server.setHandler(ctx);

        // holder: jetty处理网络请求时候的抽象
        ServletHolder holder = new ServletHolder(new RequestServlet());
        ctx.addServlet(holder,"/*");
    }

    @Override
    public void start() {
        try {
            server.start();
            server.join();
        } catch (Exception e) {
            log.error(e.getMessage(),e);
        }
    }

    @Override
    public void stop() {
        try {
            server.stop();
        } catch (Exception e) {
            log.error(e.getMessage(),e);
        }
    }

    class RequestServlet extends HttpServlet {
        @Override
        protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
            log.info("client connect");

            InputStream in = req.getInputStream();
            OutputStream out = resp.getOutputStream();

            if(handler!=null) {
                handler.onRequest(in,out);
            }
            out.flush();
        }
    }
}


/**
 * {@link TransportClient} that sends each request as an HTTP POST via
 * {@link HttpURLConnection}.
 */
public class HTTPTransportClient implements TransportClient{
    private String url;

    /**
     * Records the target URL; no network connection is opened here.
     *
     * @param peer the remote endpoint providing host and port
     */
    @Override
    public void connect(Peer peer) {
        this.url = "http://"+peer.getHost()+
                ":"+peer.getPort();
    }

    /**
     * POSTs {@code data} to the peer and returns the response body.
     *
     * @param data the request payload
     * @return the response body stream on HTTP 200, otherwise the connection's
     *         error stream (which {@link HttpURLConnection#getErrorStream()} may
     *         report as {@code null})
     * @throws IllegalStateException if {@link #connect(Peer)} has not been called,
     *                               or if any I/O error occurs
     */
    @Override
    public InputStream write(InputStream data) {
        if (url == null) {
            // Fail with an explicit message instead of a wrapped MalformedURLException.
            throw new IllegalStateException("connect(Peer) must be called before write");
        }
        try {
            HttpURLConnection httpURLConn =
                    (HttpURLConnection) new URL(url).openConnection();
            httpURLConn.setDoInput(true);
            httpURLConn.setDoOutput(true);
            httpURLConn.setUseCaches(false);
            httpURLConn.setRequestMethod("POST");

            httpURLConn.connect();
            // Close the request body stream once written so it is not leaked.
            try (java.io.OutputStream out = httpURLConn.getOutputStream()) {
                IOUtils.copy(data, out);
            }

            int resultCode = httpURLConn.getResponseCode();
            if(resultCode==HttpURLConnection.HTTP_OK) {
                return httpURLConn.getInputStream();
            } else {
                return httpURLConn.getErrorStream();
            }
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /** No-op: a fresh connection is opened for every {@link #write(InputStream)} call. */
    @Override
    public void close() {

    }
}

Server

image-20220125190635763


/**
 * Registry of the services exposed over RPC, keyed by {@link ServiceDescriptor}.
 *
 * @author hqingLau
 **/
@Slf4j
public class ServiceManager {
    private Map<ServiceDescriptor, ServiceInstance> services;

    /** Creates an empty registry backed by a {@link ConcurrentHashMap}. */
    public ServiceManager() {
        services = new ConcurrentHashMap<>();
    }

    /**
     * Registers every method returned by {@code ReflectUtils.getPublicMethods}
     * for {@code interfaceClass}, binding each one to {@code bean}.
     *
     * @param interfaceClass the service interface
     * @param bean           the object that will serve the calls
     */
    public <T> void register(Class<T> interfaceClass, Object bean) {
        for (Method exposed : ReflectUtils.getPublicMethods(interfaceClass)) {
            ServiceInstance instance = new ServiceInstance(bean, exposed);
            ServiceDescriptor descriptor = ServiceDescriptor.from(interfaceClass, exposed);
            services.put(descriptor, instance);
            log.info("register service:{} {}", descriptor.getClazz(), descriptor.getMethod());
        }
    }

    /**
     * Finds the service instance matching the request's descriptor.
     *
     * @param request the incoming request
     * @return the registered instance, or {@code null} if none matches
     */
    public ServiceInstance lookup(Request request) {
        return services.get(request.getServiceDescriptor());
    }
}

/**
 * Invokes the actual method of a registered service.
 *
 * @author hqingLau
 **/
public class ServiceInvoker {
    /**
     * Calls the instance's method on its target with the request's parameters,
     * delegating to {@code ReflectUtils.invoke}.
     *
     * @param serviceInstance the target object and method to call
     * @param request         the request carrying the call parameters
     * @return whatever {@code ReflectUtils.invoke} returns
     */
    public Object invoke(ServiceInstance serviceInstance, Request request) {
        Object target = serviceInstance.getTarget();
        return ReflectUtils.invoke(target, serviceInstance.getMethod(), request.getParameters());
    }
}


/**
 * A concrete service on the server: a target object paired with one of its methods.
 * Accessors and the all-args constructor are generated by Lombok.
 *
 * @author hqingLau
 **/
@Data
@AllArgsConstructor
public class ServiceInstance {
    // The object the method is invoked on.
    private Object target;
    // The method to invoke.
    private Method method;
}


/**
 * Server configuration: transport implementation, codec implementations and listening port.
 *
 * <p>The fields are deliberately non-final so that Lombok's {@code @Data} generates
 * setters for them. With {@code final} fields carrying initializers, every value
 * would be a fixed constant and the configuration could never be changed.
 * The defaults below apply when nothing is overridden.
 *
 * @author hqingLau
 **/
@Data
public class RpcServerConfig {
    // Transport implementation used by the server.
    private Class<? extends TransportServer> transportClass = HTTPTransportServer.class;
    // Serializer implementation for responses.
    private Class<? extends Encoder> encoderClass = JSONEncoder.class;
    // Deserializer implementation for requests.
    private Class<? extends Decoder> decoderClass = JSONDecoder.class;
    // Port the server listens on.
    private int port = 3000;
}


/**
 * RPC server: wires the transport, the codec and the service registry together
 * and dispatches each incoming request to the registered service method.
 */
@Slf4j
public class RpcServer {
    private RpcServerConfig config;
    private TransportServer net;
    private Encoder encoder;
    private Decoder decoder;
    private ServiceManager serviceManager;
    private ServiceInvoker serviceInvoker;

    /**
     * Decodes the request, invokes the matching service and always writes back an
     * encoded {@code Response}, carrying either the result or an error code and message.
     */
    private RequestHandler handler = new RequestHandler() {
        @Override
        public void onRequest(InputStream recive, OutputStream to) {
            Response resp = new Response();

            try {
                // Read to end-of-stream. available() only estimates what can be read
                // without blocking, so sizing the read with it can truncate the body.
                byte[] inBytes = IOUtils.toByteArray(recive);
                Request request = decoder.decode(inBytes,Request.class);
                log.info("get request: {}",request);

                ServiceInstance sis = serviceManager.lookup(request);
                if (sis == null) {
                    // Report an unregistered service explicitly instead of a bare NPE.
                    throw new IllegalStateException(
                            "service not found: " + request.getServiceDescriptor());
                }
                Object ret = serviceInvoker.invoke(sis,request);
                resp.setData(ret);
            } catch (Exception e) {
                // Catch everything, not just IOException: decode and invocation failures
                // must also reach the client as an error response rather than a
                // success-shaped one.
                log.warn(e.getMessage(),e);
                resp.setCode(1);
                resp.setMessage("RpcServer got error: "+
                        e.getClass().getName()+" "+
                        e.getMessage());
            } finally {
                try {
                    byte[] outBytes = encoder.encode(resp);
                    to.write(outBytes);
                    log.info("rpc send data");
                } catch (IOException e) {
                    log.warn(e.getMessage(),e);
                }
            }
        }
    };

    /**
     * Builds the server from {@code config}, instantiating the configured transport
     * and codec classes through {@code ReflectUtils.newInstance}.
     *
     * @param config the server configuration
     */
    public RpcServer(RpcServerConfig config) {
        this.config = config;
        // Create everything the handler relies on before handing it to the transport.
        this.encoder = ReflectUtils.newInstance(config.getEncoderClass());
        this.decoder = ReflectUtils.newInstance(config.getDecoderClass());
        this.serviceManager = new ServiceManager();
        this.serviceInvoker = new ServiceInvoker();
        this.net = ReflectUtils.newInstance(config.getTransportClass());
        this.net.init(config.getPort(), this.handler);
    }

    /**
     * Exposes {@code bean} as the implementation of {@code interfaceClass}.
     *
     * @param interfaceClass the service interface
     * @param bean           the object that will serve the calls
     */
    public <T> void register(Class<T> interfaceClass, Object bean) {
        serviceManager.register(interfaceClass, bean);
    }

    /** Starts the underlying transport. */
    public void start() {
        this.net.start();
    }

    /** Stops the underlying transport. */
    public void stop() {
        this.net.stop();
    }
}

参考文献

司马极客视频