Rubin's Blog


Spring: Writing an IoC Framework by Hand

May 19, 2021

Preface

This post records how I wrote an IoC framework by hand, along with what I learned along the way. I hope it helps anyone who needs it.

Writing the IoC Framework

What is IoC/DI?

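IoC (Inversion of Control) is a design idea, not a concrete technology. Under IoC we no longer create objects with new ourselves. The IoC container instantiates and manages them for us, and whenever we need an object we simply ask the container for it. We give up one right (creating and managing objects) and gain one benefit (we no longer have to deal with object creation, management, and everything that comes with it).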

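IoC and DI describe the same thing from different angles. IoC takes the object's point of view: the right to instantiate and manage the object is handed over (inverted) to the container. DI takes the container's point of view: the container injects the objects a bean depends on into that bean.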
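The following minimal sketch in plain Java makes the difference concrete. The class names OrderDao, OrderService and TinyContainer are hypothetical and are not part of the framework built below. Without IoC, the service creates its own dependency with new. With DI, a toy map-based container creates both objects and passes the dependency in through the constructor.

```java
import java.util.HashMap;
import java.util.Map;

public class DiDemo {

    // A dependency that the services below need
    static class OrderDao {
        String save(String item) {
            return "saved " + item;
        }
    }

    // Without IoC: the service itself controls the creation of its dependency
    static class ManualOrderService {
        private final OrderDao orderDao = new OrderDao();

        String placeOrder(String item) {
            return orderDao.save(item);
        }
    }

    // With DI: the dependency is supplied from outside, the service never calls new
    static class OrderService {
        private final OrderDao orderDao;

        OrderService(OrderDao orderDao) {
            this.orderDao = orderDao;
        }

        String placeOrder(String item) {
            return orderDao.save(item);
        }
    }

    // A toy "container": it owns creation and wiring, callers just ask for beans by name
    static class TinyContainer {
        private final Map<String, Object> beans = new HashMap<>();

        TinyContainer() {
            OrderDao orderDao = new OrderDao();
            beans.put("orderDao", orderDao);
            beans.put("orderService", new OrderService(orderDao)); // inject the dependency
        }

        Object getBean(String name) {
            return beans.get(name);
        }
    }

    public static void main(String[] args) {
        System.out.println(new ManualOrderService().placeOrder("book")); // saved book
        TinyContainer container = new TinyContainer();
        OrderService service = (OrderService) container.getBean("orderService");
        System.out.println(service.placeOrder("pen"));                   // saved pen
    }
}
```

The container-managed OrderService only declares that it needs an OrderDao. Deciding who builds the OrderDao, and when, is left to the container.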

Implementation Approach

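Since we are building an IoC framework, the overall plan is to start from the container object. We first define the container's properties and behavior, then work out how the container is initialized and how it manages the bean lifecycle. Finally, we consider the framework's extensibility and refine it further.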

Without further ado, let's get started.

Project Setup

First, let's set up the project. This means declaring the Maven dependencies, defining a few utility classes, and so on.

Maven dependencies:

<dependencies>
    <!-- JUnit for unit testing -->
    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.12</version>
    </dependency>
    <!-- dom4j -->
    <dependency>
        <groupId>dom4j</groupId>
        <artifactId>dom4j</artifactId>
        <version>1.6.1</version>
    </dependency>
    <!-- jaxen, for XPath expressions -->
    <dependency>
        <groupId>jaxen</groupId>
        <artifactId>jaxen</artifactId>
        <version>1.1.6</version>
    </dependency>
    <!-- cglib -->
    <dependency>
        <groupId>cglib</groupId>
        <artifactId>cglib</artifactId>
        <version>2.1_2</version>
    </dependency>
    <!-- lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.16.18</version>
        <scope>provided</scope>
    </dependency>
    <!--reflections-->
    <dependency>
        <groupId>org.reflections</groupId>
        <artifactId>reflections</artifactId>
        <version>0.9.10</version>
    </dependency>
</dependencies>

Next, we set up a few utility classes:

/**
 * Bean helper utilities
 * Created by rubin on 2021/3/25.
 */
public class BeanUtil {

    /**
     * Derive the default bean name from a class name by lower-casing its first letter
     *
     * @param originName
     * @return
     */
    public static String getDefaultBeanName(String originName) {
        // leave null/empty names and names that already start in lower case untouched
        if (originName == null || originName.isEmpty() || Character.isLowerCase(originName.charAt(0))) {
            return originName;
        }
        return Character.toLowerCase(originName.charAt(0)) + originName.substring(1);
    }

}
import java.util.Collection;

/**
 * Collection utilities
 */
public class CollectionUtil {

    /**
     * Returns true if the collection is null or empty
     *
     * @param collection
     * @return
     */
    public static boolean isEmpty(Collection<?> collection) {
        return collection == null || collection.isEmpty();
    }

}
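import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.MethodInterceptor;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

/**
 * Proxy utilities
 */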
public class ProxyUtil {

    /**
     * Name of the InvocationHandler field declared by java.lang.reflect.Proxy
     */
    final static String AOP_JAVA__TARGET_FIELD = "h";

    /**
     * Name of the field that holds the first callback in a CGLIB-generated class
     */
    final static String AOP_CGLIB_TARGET_FIELD = "CGLIB$CALLBACK_0";

    /**
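     * Marker that CGLIB puts into the names of the classes it generates
     * (e.g. UserService$$EnhancerByCGLIB$$1a2b3c4d). A single "$" is not enough,
     * because every nested class name (Outer$Inner) contains one as well.
     */
    public static final String CGLIB_CLASS_SEPARATOR = "$$";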

    /**
     * Check whether the bean is a JDK dynamic proxy
     *
     * @param bean
     * @return
     */
    public static boolean checkIsJdkProxy(Object bean) {
        if (bean == null) {
            throw new NullPointerException("bean can not be null");
        }
        return Proxy.isProxyClass(bean.getClass());
    }

    /**
     * Check whether the bean is a CGLIB proxy
     *
     * @param bean
     * @return
     */
    public static boolean checkIsCGlibProxy(Object bean) {
        if (bean == null) {
            throw new NullPointerException("bean can not be null");
        }
        String className = bean.getClass().getName();
        return (className != null && className.contains(CGLIB_CLASS_SEPARATOR));
    }

    /**
     * Get the target class behind a JDK dynamic proxy (the handler must implement JdkProxyHandler)
     *
     * @param proxy
     * @return
     * @throws NoSuchFieldException
     */
    public static Class getJdkProxyTarget(Object proxy) throws NoSuchFieldException, IllegalAccessException {
        Field h = proxy.getClass().getSuperclass().getDeclaredField(AOP_JAVA__TARGET_FIELD);
        h.setAccessible(true);
        JdkProxyHandler jdkProxyHandler = (JdkProxyHandler) h.get(proxy);
        return jdkProxyHandler.getTargetClass();
    }

    /**
     * Get the target class behind a CGLIB proxy (the callback must implement CGlibProxyMethodInterceptor)
     *
     * @param proxy
     * @return
     */
    public static Class getCGlibProxyTarget(Object proxy) throws NoSuchFieldException, IllegalAccessException {
        Field h = proxy.getClass().getDeclaredField(AOP_CGLIB_TARGET_FIELD);
        h.setAccessible(true);
        CGlibProxyMethodInterceptor cGlibProxyMethodInterceptor = (CGlibProxyMethodInterceptor) h.get(proxy);
        return cGlibProxyMethodInterceptor.getTargetClass();
    }

    /**
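     * Create a JDK dynamic proxy for obj, implementing all of obj's interfaces
     *
     * @param obj the delegate (target) object
     * @return the proxy object
     */
    public static Object getJdkProxy(Object obj, InvocationHandler invocationHandler) {
        // create the proxy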
        return Proxy.newProxyInstance(obj.getClass().getClassLoader(), obj.getClass().getInterfaces(), invocationHandler);
    }

    /**
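     * Create a JDK dynamic proxy implementing all interfaces of clazz.
     * The proxy implements those interfaces but is NOT an instance of clazz itself,
     * so the result is returned as Object and must be cast to one of the interfaces.
     *
     * @param clazz
     * @return
     */
    public static Object getJdkProxy(Class<?> clazz, InvocationHandler invocationHandler) {
        // create the proxy
        return Proxy.newProxyInstance(clazz.getClassLoader(), clazz.getInterfaces(), invocationHandler);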
    }

    /**
     * Create a JDK dynamic proxy for a single interface
     *
     * @param interfaceClazz
     * @return
     */
    public static <T> T getInterfaceJdkProxy(Class<T> interfaceClazz, InvocationHandler invocationHandler) {
        // create the interface proxy
        return (T) Proxy.newProxyInstance(interfaceClazz.getClassLoader(), new Class[]{interfaceClazz}, invocationHandler);
    }

    /**
     * Create a CGLIB proxy, i.e. a generated subclass of obj's class
     *
     * @param obj the delegate object whose class is subclassed
     * @return
     */
    public static Object getCglibProxy(Object obj, MethodInterceptor methodInterceptor) {
        return Enhancer.create(obj.getClass(), methodInterceptor);
    }

    /**
     * Create a CGLIB proxy, i.e. a generated subclass of clazz
     *
     * @param clazz the class to subclass
     * @return
     */
    public static <T> T getCglibProxy(Class<T> clazz, MethodInterceptor methodInterceptor) {
        return (T) Enhancer.create(clazz, methodInterceptor);
    }

}
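import org.reflections.Reflections;
import org.reflections.scanners.FieldAnnotationsScanner;
import org.reflections.scanners.SubTypesScanner;
import org.reflections.scanners.TypeAnnotationsScanner;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;

import java.io.File;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Reflection utilities
 */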
public class ReflectUtil {

    /**
     * Create an instance from a fully qualified class name.
     * Checked reflection exceptions are wrapped in a RuntimeException.
     *
     * @param classPath
     * @return
     */
    public static Object createInstance(String classPath) {
        Object o;
        try {
            Class clazz = Class.forName(classPath);
            o = clazz.newInstance();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            throw new RuntimeException("the class " + classPath + " can not be found");
        } catch (IllegalAccessException e) {
            e.printStackTrace();
            throw new RuntimeException("load class " + classPath + " fail");
        } catch (InstantiationException e) {
            e.printStackTrace();
            throw new RuntimeException("load class " + classPath + " fail");
        }
        return o;
    }

    /**
     * Create an instance of the class through its no-arg constructor
     *
     * @param clazz
     * @return
     */
    public static Object createInstance(Class clazz) {
        Object o;
        try {
            o = clazz.newInstance();
        } catch (InstantiationException e) {
            e.printStackTrace();
            throw new RuntimeException("load class " + clazz.getCanonicalName() + " fail");
        } catch (IllegalAccessException e) {
            e.printStackTrace();
            throw new RuntimeException("load class " + clazz.getCanonicalName() + " fail");
        }
        return o;
    }

    /**
     * Load a Class by its fully qualified name
     *
     * @param classPath
     * @return
     */
    public static Class getClass(String classPath) {
        try {
            return Class.forName(classPath);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            throw new RuntimeException("the class " + classPath + " can not be found");
        }
    }

    /**
     * Inject a field value directly via reflection, bypassing setters
     *
     * @param o
     * @param fieldName
     * @param fieldValue
     */
    public static void injectFieldValue(Object o, String fieldName, Object fieldValue) {
        try {
            Field field = o.getClass().getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(o, fieldValue);
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
            throw new RuntimeException("there is no field named " + fieldName + " in class " + o.getClass().getName());
        } catch (IllegalAccessException e) {
            e.printStackTrace();
            throw new RuntimeException("inject field value fail : " + fieldName);
        }
    }

    /**
     * Inject a value by calling the matching setter (setXxx, matched case-insensitively)
     *
     * @param o
     * @param fieldName
     * @param fieldValue
     */
    public static void setFieldValue(Object o, String fieldName, Object fieldValue) {
        Method[] declaredMethods = o.getClass().getDeclaredMethods();
        for (int i = 0; i < declaredMethods.length; i++) {
            if (declaredMethods[i].getName().equalsIgnoreCase("set" + fieldName)) {
                try {
                    declaredMethods[i].invoke(o, new Object[]{fieldValue});
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                    throw new RuntimeException("set field value fail : " + fieldName);
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                    throw new RuntimeException("set field value fail : " + fieldName);
                }
            }
        }
    }

    /**
     * Get all implementations of the interface (only the hard-coded package com.rubin.spring is scanned)
     *
     * @param interfaceClass
     * @return
     */
    public static Set<Class> getAllSubTypeOf(Class interfaceClass) {
        if (!interfaceClass.isInterface()) {
            throw new RuntimeException("the class must be an interface");
        }
        Reflections reflections = new Reflections("com.rubin.spring");
        return reflections.getSubTypesOf(interfaceClass);
    }

    /**
     * Get all interfaces under the given package that carry the given annotation
     *
     * @param scanPath
     * @return
     */
    public static Set<Class> getAllInterfaceWithAnnotation(String scanPath, Class annotationClass) {
        if (StringUtil.isBlank(scanPath)) {
            return new HashSet<>();
        }
        Reflections reflections = new Reflections(new ConfigurationBuilder().setUrls(ClasspathHelper.forPackage(scanPath)).setScanners(new FieldAnnotationsScanner(), new TypeAnnotationsScanner(), new SubTypesScanner(false)));
        Set<Class> allTypes = reflections.getTypesAnnotatedWith(annotationClass);
        if (CollectionUtil.isEmpty(allTypes)) {
            return new HashSet<>();
        }
        Set<Class> collect = allTypes.stream().filter(Class::isInterface).collect(Collectors.toSet());
        return collect;
    }

    /**
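     * Get all classes under the given package that carry the annotation, either directly or through an annotation that is itself meta-annotated with it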
     *
     * @param scanPath
     * @param annotationClass
     * @return
     */
    public static Set<Class> getAllObjectClassWithAnnotationOrSubAnnotation(String scanPath, Class annotationClass) {
        Set<String> classPaths = new HashSet<>();
        doGetAllObjectClassPaths(classPaths, scanPath);
        Set<Class> classes = classPaths.stream()
                .map(ReflectUtil::getClass)
                .filter(clazz -> ReflectUtil.hasAnnotationOrSubAnnotation(clazz, annotationClass))
                .collect(Collectors.toSet());
        return classes;
    }

    /**
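     * Recursively collect the fully qualified names of all classes under the package.
     * Note: this walks the file system, so it only works when the classes live in a
     * directory on the classpath (e.g. target/classes), not inside a JAR.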
     *
     * @param classPaths
     * @param scanPath
     */
    private static void doGetAllObjectClassPaths(Set<String> classPaths, String scanPath) {
        String folderPath = Thread.currentThread().getContextClassLoader().getResource(StringUtil.EMPTY).getPath() + scanPath.replace(".", File.separator);
        File folder = new File(folderPath);
        File[] files = folder.listFiles();
        if (files == null) {
            return;
        }
        for (int i = 0; i < files.length; i++) {
            if (files[i].isDirectory()) {
                doGetAllObjectClassPaths(classPaths, scanPath + "." + files[i].getName());
            } else if (files[i].getName().endsWith(".class")) {
                classPaths.add(scanPath + "." + files[i].getName().replace(".class", StringUtil.EMPTY));
            }
        }
    }

    /**
     * Check whether the class carries the annotation directly or through a meta-annotated annotation
     *
     * @param clazz
     * @param annotationClass
     * @return
     */
    public static boolean hasAnnotationOrSubAnnotation(Class clazz, Class annotationClass) {
        if (clazz.isAnnotationPresent(annotationClass)) {
            return true;
        }
        Annotation[] annotations = clazz.getAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            if (annotations[i].annotationType().isAnnotationPresent(annotationClass)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Get the interfaces directly implemented by clazz
     *
     * @param clazz
     * @return
     */
    public static List<Class> getInterfaces(Class clazz) {
        List<Class> interfaces = new ArrayList<>();
        Class[] interfaceArr = clazz.getInterfaces();
        if (interfaceArr != null && interfaceArr.length > 0) {
            interfaces = Arrays.asList(interfaceArr);
        }
        return interfaces;
    }

    /**
     * Get the annotation on clazz that is meta-annotated with annotationClass
     *
     * @param clazz
     * @param annotationClass
     * @return
     */
    public static Annotation getSubAnnotation(Class clazz, Class annotationClass) {
        Annotation[] annotations = clazz.getAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            if (annotations[i].annotationType().isAnnotationPresent(annotationClass)) {
                return annotations[i];
            }
        }
        throw new RuntimeException("no annotation on the class is meta-annotated with " + annotationClass.getCanonicalName());
    }

}
/**
 * String utilities
 */
public class StringUtil {

    /**
     * Empty string constant
     */
    public static final String EMPTY = "";

    /**
     * Returns true if the string is null or empty (whitespace-only strings are not treated as blank)
     *
     * @param s
     * @return
     */
    public static boolean isBlank(String s) {
        return null == s || StringUtil.EMPTY.equals(s);
    }

    /**
     * Returns true if any of the strings is blank
     *
     * @param stringArr
     * @return
     */
    public static boolean isAnyBlank(String... stringArr) {
        for (int i = 0; i < stringArr.length; i++) {
            if (StringUtil.isBlank(stringArr[i])) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if none of the strings is blank
     *
     * @param stringArr
     * @return
     */
    public static boolean isNoneBlank(String... stringArr) {
        return !StringUtil.isAnyBlank(stringArr);
    }

    /**
     * Returns true if the string is not blank
     *
     * @param s
     * @return
     */
    public static boolean isNotBlank(String s) {
        return !StringUtil.isBlank(s);
    }

}

The beans in our container may be proxy objects. We therefore define two interfaces that let the container find the target class behind a proxy:

/**
 * Implemented by JDK proxy InvocationHandlers so the container can see the proxied target class
 */
public interface JdkProxyHandler {

    /**
     * Get the class of the proxied target
     *
     * @return
     */
    Class<?> getTargetClass();

}
/**
 * Implemented by CGLIB method interceptors so the container can see the proxied target class
 * Created by rubin on 2021/3/25.
 */
public interface CGlibProxyMethodInterceptor {

    /**
     * Get the class of the proxied target
     *
     * @return
     */
    Class<?> getTargetClass();

}
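Here is a minimal sketch of how a container can unwrap a JDK proxy through an interface of this kind. TargetAwareHandler is a hypothetical stand-in for JdkProxyHandler, and Greeter, GreeterImpl and LoggingHandler are likewise illustrative. ProxyUtil.getJdkProxyTarget reads the private h field reflectively. This sketch instead uses the public JDK method Proxy.getInvocationHandler, which returns the same handler.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class ProxyTargetDemo {

    // Hypothetical stand-in for JdkProxyHandler: a handler that knows what it wraps
    interface TargetAwareHandler extends InvocationHandler {
        Class<?> getTargetClass();
    }

    interface Greeter {
        String greet(String name);
    }

    static class GreeterImpl implements Greeter {
        public String greet(String name) {
            return "hello " + name;
        }
    }

    static class LoggingHandler implements TargetAwareHandler {
        private final Object target;

        LoggingHandler(Object target) {
            this.target = target;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // delegate to the real object; a real handler would add behavior here
            return method.invoke(target, args);
        }

        @Override
        public Class<?> getTargetClass() {
            return target.getClass();
        }
    }

    // Unwrap a JDK proxy to find the class the container actually registered
    static Class<?> targetClassOf(Object bean) {
        if (Proxy.isProxyClass(bean.getClass())) {
            InvocationHandler h = Proxy.getInvocationHandler(bean);
            if (h instanceof TargetAwareHandler) {
                return ((TargetAwareHandler) h).getTargetClass();
            }
        }
        return bean.getClass();
    }

    static Greeter createProxy() {
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[]{Greeter.class},
                new LoggingHandler(new GreeterImpl()));
    }

    public static void main(String[] args) {
        Greeter proxy = createProxy();
        System.out.println(proxy.greet("rubin"));                 // hello rubin
        System.out.println(targetClassOf(proxy).getSimpleName()); // GreeterImpl
    }
}
```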

Finally, we define a resource-loading class:

import java.io.InputStream;

/**
 * Resource loader
 */
public class Resources {

    /**
     * Open a classpath resource as an InputStream; returns null if the resource does not exist
     *
     * @param filePath
     * @return
     */
    public static InputStream readFileAsInputStream(String filePath) {
        return Resources.class.getClassLoader().getResourceAsStream(filePath);
    }

}

That completes the preparation. Next comes the framework itself.

Initializing the Container Object

First, we need a top-level container that defines the methods common to all containers:

/**
 * Bean factory: loads the defined bean instances and provides methods to retrieve them
 */
public interface BeanFactory {

    /**
     * Get a bean by name
     *
     * @param beanName
     * @return
     */
    Object getBean(String beanName);

    /**
     * Get a bean by type
     *
     * @param beanClass
     * @return
     */
    Object getBean(Class<?> beanClass);

    /**
     * Whether the container holds a bean with the given name
     *
     * @param beanName
     * @return
     */
    boolean containsBean(String beanName);

}

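The interface above is the container's top-level interface. It defines three methods:

  • getting a bean by name;
  • getting a bean by type;
  • checking by name whether the container holds a given bean.

For the by-type lookup we allow only one bean per type, so an interface may have only one implementation. Lifting this restriction is left as a later extension point.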
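The one-bean-per-type rule can be sketched with a plain map. Registry is hypothetical and is not the DefaultAutowiredBeanFactory shown later. It also matches candidates with Class.isInstance, which is broader than the framework's own check: it accepts superclasses and indirectly implemented interfaces too.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class TypeLookupDemo {

    interface Repository {}
    static class UserRepository implements Repository {}
    static class OrderRepository implements Repository {}

    // Hypothetical minimal registry: bean name -> bean instance
    static class Registry {
        private final Map<String, Object> beans = new LinkedHashMap<>();

        void register(String name, Object bean) {
            if (beans.containsKey(name)) {
                throw new IllegalStateException("duplicate bean name: " + name);
            }
            beans.put(name, bean);
        }

        boolean containsBean(String name) {
            return beans.containsKey(name);
        }

        // Lookup by type: exactly one candidate must match
        <T> T getBean(Class<T> type) {
            List<Object> hits = beans.values().stream()
                    .filter(type::isInstance)
                    .collect(Collectors.toList());
            if (hits.isEmpty()) {
                throw new IllegalStateException("no bean of type " + type.getName());
            }
            if (hits.size() > 1) {
                throw new IllegalStateException("more than one bean of type " + type.getName());
            }
            return type.cast(hits.get(0));
        }
    }

    public static void main(String[] args) {
        Registry registry = new Registry();
        registry.register("userRepository", new UserRepository());
        System.out.println(registry.getBean(Repository.class).getClass().getSimpleName()); // UserRepository

        registry.register("orderRepository", new OrderRepository());
        try {
            registry.getBean(Repository.class); // now ambiguous
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // more than one bean of type ...
        }
    }
}
```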

To support container initialization later on, we define a sub-interface that adds bean registration:

public interface AutowiredBeanFactory extends BeanFactory {

    /**
     * Register a bean object
     */
    void registerBean(String beanName, Object beanObject);

    /**
     * Eagerly instantiate all singleton beans
     */
    void preInstantiateSingletons();

}

Next, the default implementation:

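import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class DefaultAutowiredBeanFactory implements AutowiredBeanFactory {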

    /**
     * First-level cache: the singleton pool of fully initialized beans
     */
    private Map<String, Object> singletonObjects;

    /**
     * Second-level cache: early-exposed beans (instantiated but not yet populated), used to resolve circular dependencies
     */
    private Map<String, Object> objectMappings;

    /**
     * Map from bean id to bean definition
     */
    private Map<String, BeanDefinition> beanDefinitionMap;

    /**
     * Bean definition loader
     */
    private BeanDefinitionLoader beanDefinitionLoader;

    public DefaultAutowiredBeanFactory(BeanDefinitionLoader beanDefinitionLoader) {
        this.singletonObjects = new ConcurrentHashMap<>();
        this.objectMappings = new ConcurrentHashMap<>();
        this.beanDefinitionMap = new ConcurrentHashMap<>();
        this.beanDefinitionLoader = beanDefinitionLoader;
        this.beanDefinitionMap.putAll(this.beanDefinitionLoader.load());
    }

    /**
     * Get a bean by name
     *
     * @param beanName
     * @return
     */
    @Override
    public Object getBean(String beanName) {
        Object bean = singletonObjects.get(beanName);
        if (bean == null) {
            BeanDefinition beanDefinition = getBeanDefinitionByName(beanName);
            bean = createBean(beanDefinition);
        }
        return bean;
    }

    /**
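     * Get a bean by type.
     * Only one bean per type is supported: if several beans match (for an interface,
     * several implementations), an exception is thrown.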
     *
     * @param beanClass
     * @return
     */
    @Override
    public Object getBean(Class<?> beanClass) {
        Object bean;
        // interface type: look up its implementation; more than one match is an error
        if (beanClass.isInterface()) {
            bean = checkAndGetImplObject(beanClass);
        } else {
            bean = checkAndGetObject(beanClass);
        }
        if (bean == null) {
            BeanDefinition beanDefinition = getBeanDefinitionByClass(beanClass);
            bean = createBean(beanDefinition);
        }
        return bean;
    }

    /**
     * Whether the container holds a bean with the given name
     *
     * @param beanName
     * @return
     */
    @Override
    public boolean containsBean(String beanName) {
        return singletonObjects.containsKey(beanName);
    }

    /**
     * Register a bean object, running all BeanPostProcessors on it first
     *
     * @param beanName
     * @param beanObject
     */
    @Override
    public synchronized void registerBean(String beanName, Object beanObject) {
        beanObject = triggerAllBeanPostProcessors(beanName, beanObject);
        if (containsBean(beanName)) {
            throw new RuntimeException("a bean named " + beanName + " already exists");
        }
        singletonObjects.put(beanName, beanObject);
    }

    /**
     * 准备实例化所有单例对象
     */
    @Override
    public void preInstantiateSingletons() {
        if (beanDefinitionMap.size() == 0) {
            return;
        }
        for (BeanDefinition beanDefinition : beanDefinitionMap.values()) {
            createBean(beanDefinition);
        }
    }

    /**
     * Create and wire a bean; once wired, it is registered in the container automatically
     *
     * @param beanDefinition
     * @return
     */
    private Object createBean(BeanDefinition beanDefinition) {
        Object beanObject;
        if ((beanObject = singletonObjects.get(beanDefinition.getBeanName())) != null) {
            return beanObject;
        }
        if ((beanObject = objectMappings.get(beanDefinition.getBeanName())) != null) {
            return beanObject;
        }
        beanObject = doCreateBean(beanDefinition);
        return beanObject;
    }

    /**
     * Actually create the bean
     *
     * @param beanDefinition
     * @return
     */
    private Object doCreateBean(BeanDefinition beanDefinition) {
        if (beanDefinition.getBeanClass().isInterface()) {
            throw new RuntimeException("an interface cannot be instantiated");
        }
        Object beanObject = ReflectUtil.createInstance(beanDefinition.getBeanClass());
        // expose the raw instance early through the second-level cache
        objectMappings.put(beanDefinition.getBeanName(), beanObject);
        fillProperties(beanObject, beanDefinition.getBeanClass(), beanDefinition.getFieldDefinitions());
        invokeInitMethod(beanDefinition, beanObject);
        // properties populated: remove from the second-level cache and register in the first-level cache
        objectMappings.remove(beanDefinition.getBeanName());
        registerBean(beanDefinition.getBeanName(), beanObject);
        return beanObject;
    }

    /**
     * Invoke the configured init-method, if any
     *
     * @param beanDefinition
     * @param beanObject
     */
    private void invokeInitMethod(BeanDefinition beanDefinition, Object beanObject) {
        if (StringUtil.isNotBlank(beanDefinition.getInitMethodName())) {
            Method[] declaredMethods = beanObject.getClass().getDeclaredMethods();
            for (int i = 0; i < declaredMethods.length; i++) {
                if (declaredMethods[i].getName().equals(beanDefinition.getInitMethodName())) {
                    try {
                        declaredMethods[i].invoke(beanObject, new Object[]{});
                    } catch (IllegalAccessException e) {
                        e.printStackTrace();
                    } catch (InvocationTargetException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    /**
     * Populate all properties defined for the bean
     *
     * @param beanObject
     * @param beanClass
     * @param fieldDefinitions
     */
    private void fillProperties(Object beanObject, Class beanClass, List<FieldDefinition> fieldDefinitions) {
        if (CollectionUtil.isEmpty(fieldDefinitions)) {
            return;
        }
        for (FieldDefinition fieldDefinition : fieldDefinitions) {
            doFillProperty(beanObject, fieldDefinition);
        }
    }

    /**
     * Populate a single property
     *
     * @param beanObject
     * @param fieldDefinition
     */
    private void doFillProperty(Object beanObject, FieldDefinition fieldDefinition) {
        Object fieldValue;
        if (fieldDefinition.isAutoWired()) {
            fieldValue = getBean(fieldDefinition.getFieldClass());
            if (fieldValue == null) {
                if (fieldDefinition.getFieldClass().isInterface()) {
                    fieldValue = createBean(getBeanDefinitionByInterfaceClass(fieldDefinition.getFieldClass()));
                } else {
                    fieldValue = createBean(getBeanDefinitionByClass(fieldDefinition.getFieldClass()));
                }
            }
        } else {
            if (fieldDefinition.isRef()) {
                fieldValue = createBean(beanDefinitionMap.get(fieldDefinition.getRef()));
            } else {
                fieldValue = handleFieldValue(fieldDefinition.getFieldClass(), fieldDefinition.getValue());
            }
        }
        if (fieldDefinition.isSetInject()) {
            ReflectUtil.setFieldValue(beanObject, fieldDefinition.getFieldName(), fieldValue);
            return;
        }
        ReflectUtil.injectFieldValue(beanObject, fieldDefinition.getFieldName(), fieldValue);
    }

    /**
     * Find the bean definition that implements the given interface
     *
     * @param fieldClass
     * @return
     */
    private BeanDefinition getBeanDefinitionByInterfaceClass(Class fieldClass) {
        List<BeanDefinition> hitBeanDefinitions = beanDefinitionMap.values().stream().filter(beanDefinition -> beanDefinition.getInterfaces().contains(fieldClass)).collect(Collectors.toList());
        if (CollectionUtil.isEmpty(hitBeanDefinitions)) {
            throw new RuntimeException("there is no bean matching the class " + fieldClass.getCanonicalName());
        }
        if (hitBeanDefinitions.size() > 1) {
            throw new RuntimeException("there is more than one bean matching the class " + fieldClass.getCanonicalName());
        }
        return hitBeanDefinitions.get(0);
    }

    /**
     * Get a bean definition by bean name
     *
     * @param beanName
     * @return
     */
    private BeanDefinition getBeanDefinitionByName(String beanName) {
        BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
        if (beanDefinition == null) {
            throw new RuntimeException("there is no bean named " + beanName);
        }
        return beanDefinition;
    }

    /**
     * Find the bean definition whose class equals, or directly implements, the given type
     *
     * @param fieldClass
     * @return
     */
    private BeanDefinition getBeanDefinitionByClass(Class fieldClass) {
        List<BeanDefinition> hitBeanDefinitions = beanDefinitionMap.values().stream()
                .filter(beanDefinition -> beanDefinition.getBeanClass().equals(fieldClass) || beanDefinition.getInterfaces().contains(fieldClass))
                .collect(Collectors.toList());
        if (CollectionUtil.isEmpty(hitBeanDefinitions)) {
            throw new RuntimeException("there is no bean matching the class " + fieldClass.getCanonicalName());
        }
        if (hitBeanDefinitions.size() > 1) {
            throw new RuntimeException("there is more than one bean matching the class " + fieldClass.getCanonicalName());
        }
        return hitBeanDefinitions.get(0);
    }

    /**
     * Convert a string value to the target primitive, wrapper, or String type (returns null for other types)
     *
     * @param paramType
     * @param value
     * @return
     */
    private Object handleFieldValue(Class paramType, String value) {
        if (byte.class.equals(paramType)) {
            return Byte.valueOf(value).byteValue();
        } else if (short.class.equals(paramType)) {
            return Short.valueOf(value).shortValue();
        } else if (int.class.equals(paramType)) {
            return Integer.valueOf(value).intValue();
        } else if (long.class.equals(paramType)) {
            return Long.valueOf(value).longValue();
        } else if (double.class.equals(paramType)) {
            return Double.valueOf(value).doubleValue();
        } else if (float.class.equals(paramType)) {
            return Float.valueOf(value).floatValue();
        } else if (char.class.equals(paramType)) {
            return Character.valueOf(value.toCharArray()[0]).charValue();
        } else if (boolean.class.equals(paramType)) {
            return Boolean.valueOf(value).booleanValue();
        } else if (Byte.class.equals(paramType)) {
            return Byte.valueOf(value);
        } else if (Short.class.equals(paramType)) {
            return Short.valueOf(value);
        } else if (Integer.class.equals(paramType)) {
            return Integer.valueOf(value);
        } else if (Long.class.equals(paramType)) {
            return Long.valueOf(value);
        } else if (Double.class.equals(paramType)) {
            return Double.valueOf(value);
        } else if (Float.class.equals(paramType)) {
            return Float.valueOf(value);
        } else if (Character.class.equals(paramType)) {
            return Character.valueOf(value.toCharArray()[0]);
        } else if (Boolean.class.equals(paramType)) {
            return Boolean.valueOf(value);
        } else if (String.class.equals(paramType)) {
            return value;
        }
        return null;
    }

    /**
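     * Run all BeanPostProcessor implementations on the bean in ascending getOrder().
     * Note: the classpath is rescanned and the processors are re-instantiated on every
     * registration; caching the sorted list once would be cheaper.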
     *
     * @param beanName
     * @param beanObject
     * @return
     */
    private Object triggerAllBeanPostProcessors(String beanName, Object beanObject) {
        Set<Class> beanPostProcessorClasses = ReflectUtil.getAllSubTypeOf(BeanPostProcessor.class);
        List<BeanPostProcessor> beanPostProcessors = beanPostProcessorClasses.stream()
                .map(ReflectUtil::createInstance)
                .map(o -> (BeanPostProcessor) o)
                .sorted(Comparator.comparing(BeanPostProcessor::getOrder))
                .collect(Collectors.toList());
        if (CollectionUtil.isEmpty(beanPostProcessors)) {
            return beanObject;
        }
        for (BeanPostProcessor beanPostProcessor : beanPostProcessors) {
            beanObject = beanPostProcessor.postProcessBeforeRegister(this, beanName, beanObject);
        }
        return beanObject;
    }

    /**
     * Count the beans that implement the interface and return the matching one.
     * Note: only classes that implement the interface directly are detected;
     * indirect implementations (via a superclass or a sub-interface) are not considered.
     *
     * @param interfaceClass
     * @return
     */
    private Object checkAndGetImplObject(Class<?> interfaceClass) {
        int matchCount = 0;
        Object bean = null;
        for (Object value : singletonObjects.values()) {
            // unwrap proxies first: a CGLIB proxy class does not itself list the target's interfaces
            Class beanClass = checkProxyAndGetTarget(value);
            if (beanClass.equals(interfaceClass)
                    || Arrays.asList(beanClass.getInterfaces()).contains(interfaceClass)) {
                bean = value;
                matchCount++;
            }
        }
        if (matchCount > 1) {
            throw new RuntimeException("more than one bean matches the class " + interfaceClass.getCanonicalName());
        }
        return bean;
    }

    /**
     * Return the single bean whose (unwrapped) class equals beanClass
     *
     * @param beanClass
     * @return
     */
    private Object checkAndGetObject(Class<?> beanClass) {
        int matchCount = 0;
        Object bean = null;
        for (Object value : singletonObjects.values()) {
            Class targetClass = checkProxyAndGetTarget(value);
            if (beanClass.equals(targetClass)) {
                bean = value;
                matchCount++;
            }
        }
        if (matchCount > 1) {
            throw new RuntimeException("more than one bean matches the class " + beanClass.getCanonicalName());
        }
        return bean;
    }

    /**
     * If value is a proxy, return its target class; otherwise return value's own class
     *
     * @param value
     * @return
     */
    private Class checkProxyAndGetTarget(Object value) {
        if (ProxyUtil.checkIsJdkProxy(value)) {
            try {
                return ProxyUtil.getJdkProxyTarget(value);
            } catch (NoSuchFieldException e) {
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }
        if (ProxyUtil.checkIsCGlibProxy(value)) {
            try {
                return ProxyUtil.getCGlibProxyTarget(value);
            } catch (NoSuchFieldException e) {
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }
        return value.getClass();
    }

}

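There is a lot of code here, but the logic is straightforward. Start reading from the constructor: the basic flow is to load the bean definitions and then instantiate them.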

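The key points are:

  • We use a two-level cache so that beans can take part in circular dependencies while they are being instantiated and wired. Spring uses a three-level cache; its extra level stores ObjectFactory instances, so an early reference can already be a proxy. In our version, if a BeanPostProcessor later wraps a bean, the beans that received its early reference keep the raw object.
  • The bean wiring logic lives here. It is worth reading carefully, because it prepares you for reading the Spring source code later.
  • Loading bean definitions is delegated to a loader, which applies the single-responsibility principle.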
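The two-level cache idea fits in a small, self-contained sketch. MiniContainer, A and B are hypothetical. Beans are keyed by class instead of by name, and any field whose type is a registered bean class is injected. The key step matches doCreateBean: the raw instance goes into the second-level cache before its fields are populated. A bean that depends back on it therefore receives that early reference instead of triggering endless recursion.

```java
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class CircularDependencyDemo {

    // Two beans that depend on each other
    static class A { B b; }
    static class B { A a; }

    // Hypothetical mini container: beans keyed by class, fields autowired by type
    static class MiniContainer {
        private final Map<Class<?>, Object> singletonObjects = new HashMap<>(); // level 1: fully populated
        private final Map<Class<?>, Object> earlyObjects = new HashMap<>();     // level 2: created, not populated
        private final Set<Class<?>> beanClasses;

        MiniContainer(Class<?>... beanClasses) {
            this.beanClasses = new HashSet<>(Arrays.asList(beanClasses));
        }

        Object getBean(Class<?> type) {
            Object bean = singletonObjects.get(type);
            if (bean == null) {
                bean = earlyObjects.get(type); // circular reference: reuse the early instance
            }
            return bean != null ? bean : createBean(type);
        }

        private Object createBean(Class<?> type) {
            try {
                Object bean = type.getDeclaredConstructor().newInstance();
                earlyObjects.put(type, bean); // expose the raw instance before populating it
                for (Field field : type.getDeclaredFields()) {
                    if (beanClasses.contains(field.getType())) {
                        field.setAccessible(true);
                        field.set(bean, getBean(field.getType()));
                    }
                }
                earlyObjects.remove(type);    // populated: promote to the first-level cache
                singletonObjects.put(type, bean);
                return bean;
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot create " + type.getName(), e);
            }
        }
    }

    public static void main(String[] args) {
        MiniContainer container = new MiniContainer(A.class, B.class);
        A a = (A) container.getBean(A.class);
        B b = (B) container.getBean(B.class);
        System.out.println(a.b == b); // true
        System.out.println(b.a == a); // true
    }
}
```

Like the framework, this sketch hands out the raw early object. If something later replaced A with a proxy, B would still point to the unwrapped A.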

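The code above contains our first extension point: registering a bean triggers the bean post-processors, which let users customize or wrap a bean before it is put into the container. BeanPostProcessor is defined as follows: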

/**
 * Hook for customizing or wrapping a bean after it has been instantiated and populated, but before it is registered in the container
 */
public interface BeanPostProcessor {

    /**
     * Called before the bean is registered in the container; the returned object is what gets registered
     *
     * @param beanFactory
     * @param beanName
     * @param beanObject
     * @return
     */
    Object postProcessBeforeRegister(BeanFactory beanFactory, String beanName, Object beanObject);

    /**
     * Execution order; processors run in ascending order
     *
     * @return
     */
    Integer getOrder();

}
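This sketch shows how ordered post-processors chain together. PostProcessor is a simplified, hypothetical copy of BeanPostProcessor without the BeanFactory parameter. applyAll mirrors what triggerAllBeanPostProcessors does: it sorts by getOrder() in ascending order and feeds each processor the result of the previous one. UpperCaseProcessor shows the typical use case of replacing the bean with a JDK proxy.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class PostProcessorDemo {

    // Simplified, hypothetical version of BeanPostProcessor (no BeanFactory parameter)
    interface PostProcessor {
        Object postProcessBeforeRegister(String beanName, Object beanObject);

        Integer getOrder();
    }

    interface Greeter {
        String greet(String name);
    }

    static class SimpleGreeter implements Greeter {
        public String greet(String name) {
            return "hello " + name;
        }
    }

    // order 1: wraps every Greeter in a JDK proxy that upper-cases String results
    static class UpperCaseProcessor implements PostProcessor {
        public Object postProcessBeforeRegister(String beanName, Object beanObject) {
            if (!(beanObject instanceof Greeter)) {
                return beanObject;
            }
            Greeter target = (Greeter) beanObject;
            return Proxy.newProxyInstance(Greeter.class.getClassLoader(), new Class<?>[]{Greeter.class},
                    (proxy, method, args) -> {
                        Object result = method.invoke(target, args);
                        return result instanceof String ? ((String) result).toUpperCase() : result;
                    });
        }

        public Integer getOrder() {
            return 1;
        }
    }

    // order 0: records which beans it has seen and returns them unchanged
    static class RecordingProcessor implements PostProcessor {
        final List<String> seen = new ArrayList<>();

        public Object postProcessBeforeRegister(String beanName, Object beanObject) {
            seen.add(beanName);
            return beanObject;
        }

        public Integer getOrder() {
            return 0;
        }
    }

    // Apply processors in ascending getOrder(), feeding each one the previous result
    static Object applyAll(List<PostProcessor> processors, String beanName, Object bean) {
        List<PostProcessor> sorted = new ArrayList<>(processors);
        sorted.sort(Comparator.comparing(PostProcessor::getOrder));
        for (PostProcessor processor : sorted) {
            bean = processor.postProcessBeforeRegister(beanName, bean);
        }
        return bean;
    }

    public static void main(String[] args) {
        RecordingProcessor recorder = new RecordingProcessor();
        List<PostProcessor> processors = Arrays.asList(new UpperCaseProcessor(), recorder);
        Greeter greeter = (Greeter) applyAll(processors, "greeter", new SimpleGreeter());
        System.out.println(greeter.greet("rubin")); // HELLO RUBIN
        System.out.println(recorder.seen);          // [greeter]
    }
}
```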

Initializing the Application Context

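A context can be thought of as our application object. It holds all kinds of things: the runtime environment, runtime parameters, the container, middleware, and so on. Let's start with the simplest possible context interface: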

/**
 * Top-level context interface
 */
public interface ApplicationContext extends BeanFactory {

    void refresh();

}
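refresh() is meant to be a template method: the base class fixes the order of the startup steps and leaves a hook for subclasses. The AbstractApplicationContext below is organized the same way. Here is a stripped-down, hypothetical sketch; BaseContext, WebContext and the step names are illustrative only.

```java
import java.util.ArrayList;
import java.util.List;

public class RefreshTemplateDemo {

    // Skeleton of a context whose refresh() fixes the order of the startup steps
    static abstract class BaseContext {
        final List<String> steps = new ArrayList<>();

        public final void refresh() {
            steps.add("prepareBeanFactory");
            steps.add("postProcessBeanFactory");
            onRefresh();                      // subclass hook, always runs at this point
            steps.add("finishRefresh");
        }

        protected abstract void onRefresh();
    }

    static class WebContext extends BaseContext {
        @Override
        protected void onRefresh() {
            steps.add("startEmbeddedServer"); // hypothetical subclass-specific work
        }
    }

    public static void main(String[] args) {
        WebContext context = new WebContext();
        context.refresh();
        System.out.println(context.steps);
        // [prepareBeanFactory, postProcessBeanFactory, startEmbeddedServer, finishRefresh]
    }
}
```

The sketch calls refresh() explicitly after construction. AbstractApplicationContext instead calls it from its constructor. Calling an overridable method such as onRefresh() from a constructor runs the subclass hook before the subclass's own field initializers have executed, which can cause surprising nulls.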

Next, an abstract implementation of the interface:

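import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public abstract class AbstractApplicationContext implements ApplicationContext {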

    /**
     * The bean container
     */
    private AutowiredBeanFactory autowiredBeanFactory;

    /**
     * Bean definition loader
     */
    private BeanDefinitionLoader beanDefinitionLoader;

    public AbstractApplicationContext(BeanDefinitionLoader beanDefinitionLoader) {
        this.beanDefinitionLoader = beanDefinitionLoader;
        refresh();
    }

    @Override
    public void refresh() {

        AutowiredBeanFactory autowiredBeanFactory = getAutowiredBeanFactory();

        // prepare the BeanFactory
        prepareBeanFactory(autowiredBeanFactory);

        // post-process the BeanFactory once it has been created
        postBeanFactory(autowiredBeanFactory);

        // hook for subclasses
        onRefresh();

        // finish the refresh: eagerly instantiate all singletons
        finishRefresh();

    }

    /**
     * Extension hook invoked during refresh
     */
    protected abstract void onRefresh();


    /**
     * Prepare the BeanFactory
     *
     * @param autowiredBeanFactory
     */
    private void prepareBeanFactory(AutowiredBeanFactory autowiredBeanFactory) {
        this.autowiredBeanFactory = autowiredBeanFactory;
    }

    /**
     * Post-process the bean factory once it has been created
     *
     * @param autowiredBeanFactory
     */
    private void postBeanFactory(AutowiredBeanFactory autowiredBeanFactory) {
        // Run the custom logic of every BeanFactoryPostProcessor implementation
        triggerAllBeanFactoryPostProcessors(autowiredBeanFactory);
    }

    /**
     * Instantiate every BeanFactoryPostProcessor and invoke them in ascending order
     *
     * @param autowiredBeanFactory
     */
    private void triggerAllBeanFactoryPostProcessors(AutowiredBeanFactory autowiredBeanFactory) {
        Set<Class> beanFactoryPostProcessorClasses = ReflectUtil.getAllSubTypeOf(BeanFactoryPostProcessor.class);
        List<BeanFactoryPostProcessor> beanFactoryPostProcessors = beanFactoryPostProcessorClasses.stream()
                .map(ReflectUtil::createInstance)
                .map(o -> (BeanFactoryPostProcessor) o)
                .sorted(Comparator.comparing(BeanFactoryPostProcessor::getOrder))
                .collect(Collectors.toList());
        if (CollectionUtil.isEmpty(beanFactoryPostProcessors)) {
            return;
        }
        for (BeanFactoryPostProcessor beanFactoryPostProcessor : beanFactoryPostProcessors) {
            beanFactoryPostProcessor.postProcessBeanFactory(autowiredBeanFactory);
        }
    }

    /**
     * Finish initialization by eagerly instantiating all singletons
     */
    private void finishRefresh() {
        autowiredBeanFactory.preInstantiateSingletons();
    }

    /**
     * Return the AutowiredBeanFactory, creating a new one if none exists yet
     *
     * @return
     */
    private AutowiredBeanFactory getAutowiredBeanFactory() {
        if (autowiredBeanFactory != null) {
            return autowiredBeanFactory;
        }
        return new DefaultAutowiredBeanFactory(this.beanDefinitionLoader);
    }

    /**
     * Get a bean by name
     *
     * @param beanName
     * @return
     */
    @Override
    public Object getBean(String beanName) {
        return autowiredBeanFactory.getBean(beanName);
    }

    /**
     * Get a bean by type
     *
     * @param beanClass
     * @return
     */
    @Override
    public Object getBean(Class<?> beanClass) {
        return autowiredBeanFactory.getBean(beanClass);
    }

    /**
     * Whether the container contains a bean with the given name
     *
     * @param beanName
     * @return
     */
    @Override
    public boolean containsBean(String beanName) {
        return autowiredBeanFactory.containsBean(beanName);
    }

}

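The context implements the container interface (BeanFactory) and acts as a more capable container. It exposes a post-processing hook that runs after the bean factory has been created. Users can use this hook to register pre-built beans before the remaining singletons are instantiated.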
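The following sketch roughly illustrates that hook. It models refresh() as a template method with the same steps as above. First it prepares a factory. Then every factory post-processor may register pre-built singletons. Finally, whatever is still missing gets instantiated. `MiniFactory`, `registerSingleton`, and the string stand-ins for real instances are hypothetical. This post does not show whether AutowiredBeanFactory offers a registration method with this shape.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the refresh() template method.
class RefreshDemo {

    // Stand-in for AutowiredBeanFactory
    static class MiniFactory {
        final Map<String, Object> singletons = new LinkedHashMap<>();
        final List<String> pendingBeanNames = new ArrayList<>();

        void registerSingleton(String name, Object bean) {
            singletons.put(name, bean);
        }

        // Mirrors preInstantiateSingletons(): create whatever is not registered yet
        void preInstantiateSingletons() {
            for (String name : pendingBeanNames) {
                singletons.putIfAbsent(name, "instance of " + name);
            }
        }
    }

    // Stand-in for BeanFactoryPostProcessor
    interface FactoryPostProcessor {
        void postProcessBeanFactory(MiniFactory factory);
    }

    static MiniFactory refresh(List<String> beanNames, List<FactoryPostProcessor> processors) {
        MiniFactory factory = new MiniFactory();          // prepareBeanFactory
        factory.pendingBeanNames.addAll(beanNames);
        for (FactoryPostProcessor processor : processors) {
            processor.postProcessBeanFactory(factory);    // postBeanFactory
        }
        factory.preInstantiateSingletons();               // finishRefresh
        return factory;
    }
}
```

The post-processors run before preInstantiateSingletons(), so in this sketch a pre-registered bean takes precedence: putIfAbsent leaves it untouched.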

Next, we define two concrete implementations to complete the context:

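A context that reads an XML configuration file (it also supports annotation scanning, which must be enabled in the configuration file):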

/**
 * Application context backed by an XML configuration file
 */
public class ClassPathXmlApplicationContext extends AbstractApplicationContext {

    public ClassPathXmlApplicationContext(String configLocation) {
        super(new ClassPathXmlBeanDefinitionLoader(configLocation));
    }

    /**
     * Extension hook invoked during refresh
     */
    @Override
    protected void onRefresh() {

    }
}

A context driven by annotation scanning:

/**
 * Application context driven by annotation scanning
 */
@Getter
public class AnnotationApplicationContext extends AbstractApplicationContext {

    private String annotationScanBasePackage;

    public AnnotationApplicationContext(String annotationScanBasePackage) {
        super(new AnnotationBeanDefinitionLoader(annotationScanBasePackage));
        // assigned only after super() has already run refresh(), so onRefresh() cannot rely on it
        this.annotationScanBasePackage = annotationScanBasePackage;
    }

    /**
     * Extension hook invoked during refresh
     */
    @Override
    protected void onRefresh() {

    }

}

Finally, we define the bean factory post-processor interface:

/**
 * Runs custom logic after the BeanFactory has been instantiated
 */
public interface BeanFactoryPostProcessor {

    /**
     * Custom logic run after the BeanFactory has been initialized
     *
     * @param autowiredBeanFactory
     */
    void postProcessBeanFactory(AutowiredBeanFactory autowiredBeanFactory);

    /**
     * Execution order: processors run in ascending order
     *
     * @return
     */
    Integer getOrder();

}
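triggerAllBeanFactoryPostProcessors() finds implementations through ReflectUtil.getAllSubTypeOf, which is built on the reflections dependency and not shown here. It instantiates each one and sorts them by getOrder(). The sketch below keeps only the instantiate-and-sort part. The candidate classes are passed in explicitly, and `OrderedHook`, `First`, `Second`, and `Base` are hypothetical. Every concrete implementation is assumed to have a public no-arg constructor. The sketch also skips abstract classes and interfaces: a subtype scan may return them, and they cannot be instantiated.

```java
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch: instantiate discovered hook classes and sort them by order.
class HookLoaderDemo {

    // Stand-in for BeanFactoryPostProcessor
    interface OrderedHook {
        Integer getOrder();
    }

    public static class Second implements OrderedHook {
        public Integer getOrder() { return 20; }
    }

    public static class First implements OrderedHook {
        public Integer getOrder() { return 10; }
    }

    // Abstract implementations cannot be instantiated and are skipped
    public abstract static class Base implements OrderedHook { }

    static List<OrderedHook> instantiateSorted(List<Class<? extends OrderedHook>> candidates) {
        List<OrderedHook> hooks = new ArrayList<>();
        for (Class<? extends OrderedHook> type : candidates) {
            if (Modifier.isAbstract(type.getModifiers())) {
                continue; // interfaces and abstract classes
            }
            try {
                hooks.add(type.getDeclaredConstructor().newInstance());
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot instantiate " + type.getName(), e);
            }
        }
        // ascending getOrder(), as in the framework
        hooks.sort(Comparator.comparing(OrderedHook::getOrder));
        return hooks;
    }
}
```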

Bean Definition Loaders

First, we define the bean definition object:

/**
 * Metadata describing a bean definition
 */
@Data
public class BeanDefinition implements Serializable {

    private static final long serialVersionUID = -5802994246701645500L;

    private String beanName;

    private Class beanClass;

    private List<Class> interfaces;

    private List<FieldDefinition> fieldDefinitions;

    private String initMethodName;

}

Then the definition of a bean property:

/**
 * Definition of a bean property
 */
@Data
public class FieldDefinition implements Serializable {

    private static final long serialVersionUID = 459569594700941920L;

    private String fieldName;

    private Class fieldClass;

    private String ref;

    private String value;

    private InjectType injectType;

    private boolean autowired;

    public boolean isRef() {
        return StringUtil.isNotBlank(ref);
    }

    public boolean isSetInject() {
        return InjectType.SET.equals(injectType);
    }

    public boolean isAutoWired() {
        return autowired;
    }

}

An enum for the property injection type:

/**
 * Injection type
 */
@AllArgsConstructor
@Getter
public enum InjectType {

    /**
     * Setter injection
     */
    SET(0),
    /**
     * Field injection
     */
    FIELD(1);

    private Integer type;

}
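Here is a small reflection sketch of what the two values mean at injection time. It is not the framework's factory code. `InjectionDemo`, `injectBySetter`, and `injectByField` are hypothetical names, and the setter name is derived by capitalizing the first letter of the property. SET calls a public setXxx method. FIELD writes the field directly and needs setAccessible(true) for private fields, which is the usual case for the @Autowired and @Value fields shown later.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Method;

// Hypothetical sketch of the two injection styles named by InjectType.
class InjectionDemo {

    static class Address {
        private String province; // no setter: only field injection can fill it
        private String city;

        public void setCity(String city) { this.city = city; }
        public String getCity() { return city; }
        public String getProvince() { return province; }
    }

    // SET: call the public setter, e.g. "city" -> setCity(String)
    static void injectBySetter(Object target, String propertyName, Class<?> type, Object value) {
        String setterName = "set" + Character.toUpperCase(propertyName.charAt(0)) + propertyName.substring(1);
        try {
            Method setter = target.getClass().getMethod(setterName, type);
            setter.invoke(target, value);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("setter injection failed for " + propertyName, e);
        }
    }

    // FIELD: write the field directly; setAccessible(true) is needed for private fields
    static void injectByField(Object target, String fieldName, Object value) {
        try {
            Field field = target.getClass().getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(target, value);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("field injection failed for " + fieldName, e);
        }
    }
}
```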

Next, the annotations used for bean scanning:

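/**
 * Fields carrying this annotation are injected by type
 * Currently supported on fields only
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
}

/**
 * Classes carrying this annotation are added to the singleton pool
 * Note: currently supported on classes only; placing it on an interface causes an error
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Component {

    String value() default "";

}

/**
 * Classes carrying this annotation are added to the singleton pool
 * Note: currently supported on classes only; placing it on an interface causes an error
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Component
public @interface Service {

    String value() default "";

}

/**
 * Fields carrying this annotation are injected with the given value
 * Currently supported on fields only
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Value {

    String value() default "";

}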

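Note that @Component can itself be placed on another annotation to define a composed annotation; @Service above is one example. Our loader supports scanning for such composed annotations.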
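This post does not show how ReflectUtil.getAllObjectClassWithAnnotationOrSubAnnotation detects composed annotations. One common approach is sketched below with hypothetical names. It walks the annotations on a class and then, recursively, the annotations on each annotation type. A visited set is required because JDK meta-annotations such as @Documented are annotated with themselves. Note also that a plain Class.getAnnotation(Component.class) returns null for a class marked only with @Service.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.HashSet;
import java.util.Set;

// Hypothetical sketch of meta-annotation (composed annotation) detection.
class MetaAnnotationDemo {

    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Component { String value() default ""; }

    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @Component
    @interface Service { String value() default ""; }

    @Service("orderService")
    static class OrderService { }

    static class PlainPojo { }

    // true if the element carries `target` directly or through any chain of meta-annotations
    static boolean hasMetaAnnotation(Class<?> element, Class<? extends Annotation> target) {
        return search(element, target, new HashSet<>());
    }

    private static boolean search(Class<?> element, Class<? extends Annotation> target,
                                  Set<Class<?>> visited) {
        for (Annotation annotation : element.getAnnotations()) {
            Class<? extends Annotation> type = annotation.annotationType();
            if (type == target) {
                return true;
            }
            // visited guards against cycles such as @Documented being annotated with itself
            if (visited.add(type) && search(type, target, visited)) {
                return true;
            }
        }
        return false;
    }
}
```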

That completes the bean definition model. Finally, we define the top-level interface for bean definition loaders, and the preparation is done:

/**
 * Bean definition loader
 */
public interface BeanDefinitionLoader {

    Map<String, BeanDefinition> load();

}

XML Configuration Bean Definition Loader

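This loader reads bean definitions from an XML configuration file. It is also compatible with the annotation-scanning loader, but annotation scanning must be enabled explicitly in the configuration file (see the sample below):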

/**
 * Bean definition loader for XML configuration files
 */
public class ClassPathXmlBeanDefinitionLoader implements BeanDefinitionLoader {

    private String configLocation;

    private Map<String, BeanDefinition> beanDefinitionMap;

    public ClassPathXmlBeanDefinitionLoader(String configLocation) {
        this.configLocation = configLocation;
        this.beanDefinitionMap = new ConcurrentHashMap<>();
    }

    @Override
    public Map<String, BeanDefinition> load() {
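        // Parse the XML file
        SAXReader saxReader = new SAXReader();
        Document document;
        try {
            document = saxReader.read(Resources.readFileAsInputStream(configLocation));
        } catch (DocumentException e) {
            // fail fast: continuing would dereference a null document below
            throw new RuntimeException("failed to parse the config file " + configLocation, e);
        }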
        Element rootElement = document.getRootElement();
        loadXmlBeanDefinitions(rootElement);
        String annotationScanBasePackage = checkAnnotationScanOpenAndGetBasePackage(rootElement);
        if (StringUtil.isNotBlank(annotationScanBasePackage)) {
            loadAndMergeAnnotationBeanDefinitions(annotationScanBasePackage);
        }
        return beanDefinitionMap;
    }

    /**
     * Load the annotation-scanned bean definitions and merge them into this map, rejecting duplicate ids
     *
     * @param annotationScanBasePackage
     */
    private void loadAndMergeAnnotationBeanDefinitions(String annotationScanBasePackage) {
        Map<String, BeanDefinition> annotationBeanDefinitionMap = new AnnotationBeanDefinitionLoader(annotationScanBasePackage).load();
        if (annotationBeanDefinitionMap.size() > 0) {
            for (Map.Entry<String, BeanDefinition> annotationBeanDefinitionEntry : annotationBeanDefinitionMap.entrySet()) {
                if (beanDefinitionMap.containsKey(annotationBeanDefinitionEntry.getKey())) {
                    throw new RuntimeException("the bean id " + annotationBeanDefinitionEntry.getKey() + " has already been registered");
                }
                beanDefinitionMap.put(annotationBeanDefinitionEntry.getKey(), annotationBeanDefinitionEntry.getValue());
            }
        }
    }

    /**
     * Load the beans defined in the XML file
     *
     * @param rootElement
     */
    private void loadXmlBeanDefinitions(Element rootElement) {
        List<Element> beanElementList = rootElement.selectNodes("//bean");
        if (CollectionUtil.isEmpty(beanElementList)) {
            return;
        }
        beanElementList.stream().forEach(beanElement -> parseBeanElement(beanElement));
    }

    /**
     * Parse a bean element
     *
     * @param beanElement
     */
    private void parseBeanElement(Element beanElement) {
        String beanName = beanElement.attributeValue("id"),
                classPath = beanElement.attributeValue("class"),
                initMethod = beanElement.attributeValue("init-method"),
                initBeanFlag = beanElement.attributeValue("init-bean");
        // validate before loading the class, so a missing class attribute produces a clear error
        if (StringUtil.isAnyBlank(beanName, classPath)) {
            throw new RuntimeException("the attribute id and class can not be empty");
        }
        Class beanClass = ReflectUtil.getClass(classPath);
        List<Class> interfaces = ReflectUtil.getInterfaces(beanClass);
        BeanDefinition beanDefinition = new BeanDefinition();
        beanDefinition.setBeanName(beanName);
        beanDefinition.setBeanClass(beanClass);
        beanDefinition.setInterfaces(interfaces);
        beanDefinition.setInitMethodName(initMethod);
        List<Element> propertyElements = beanElement.selectNodes("property");
        parseBeanPropertyElements(beanDefinition, propertyElements);
        beanDefinitionMap.put(beanName, beanDefinition);
    }

    /**
     * Parse the list of property elements of a bean
     *
     * @param beanDefinition
     * @param propertyElements
     */
    private void parseBeanPropertyElements(BeanDefinition beanDefinition, List<Element> propertyElements) {
        if (CollectionUtil.isEmpty(propertyElements)) {
            beanDefinition.setFieldDefinitions(new ArrayList<>());
            return;
        }
        List<FieldDefinition> fieldDefinitionList = propertyElements.stream().map(propertyElement -> parseBeanPropertyElement(beanDefinition, propertyElement)).collect(Collectors.toList());
        beanDefinition.setFieldDefinitions(fieldDefinitionList);
    }

    /**
     * Parse a single property element
     *
     * @param beanDefinition
     * @param propertyElement
     */
    private FieldDefinition parseBeanPropertyElement(BeanDefinition beanDefinition, Element propertyElement) {
        String fieldName = propertyElement.attributeValue("name"),
                ref = propertyElement.attributeValue("ref"),
                value = propertyElement.attributeValue("value");
        if (StringUtil.isNoneBlank(ref, value)) {
            throw new RuntimeException("the property attribute ref and value can not coexist");
        }
        FieldDefinition fieldDefinition = new FieldDefinition();
        fieldDefinition.setFieldName(fieldName);
        fieldDefinition.setRef(ref);
        fieldDefinition.setValue(value);
        fieldDefinition.setFieldClass(getFieldClass(beanDefinition.getBeanClass(), fieldName));
        fieldDefinition.setInjectType(getInjectType(beanDefinition.getBeanClass(), fieldName));
        fieldDefinition.setAutowired(false);
        return fieldDefinition;
    }

    /**
     * Resolve the property type: a public field first, otherwise the parameter type of a matching setter
     *
     * @param beanClass
     * @param fieldName
     * @return
     */
    private Class getFieldClass(Class beanClass, String fieldName) {
        try {
            Field field = beanClass.getField(fieldName);
            return field.getType();
        } catch (NoSuchFieldException e) {
            Method[] declaredMethods = beanClass.getDeclaredMethods();
            for (int i = 0; i < declaredMethods.length; i++) {
                if (declaredMethods[i].getName().equalsIgnoreCase("set" + fieldName)) {
                    Class<?>[] parameterTypes = declaredMethods[i].getParameterTypes();
                    if (parameterTypes == null || parameterTypes.length == 0) {
                        throw new RuntimeException("there are no field or set method matched " + fieldName);
                    }
                    return parameterTypes[0];
                }
            }
        }
        throw new RuntimeException("there are no field or set method matched " + fieldName);
    }

    /**
     * Determine the injection type: FIELD for a public field, SET otherwise
     *
     * @param beanClass
     * @param fieldName
     * @return
     */
    private InjectType getInjectType(Class beanClass, String fieldName) {
        try {
            if (beanClass.getField(fieldName) != null) {
                return InjectType.FIELD;
            }
            return InjectType.SET;
        } catch (NoSuchFieldException e) {
            return InjectType.SET;
        }
    }

    /**
     * Check whether annotation scanning is enabled; return its base package, or null if it is disabled
     *
     * @param rootElement
     * @return
     */
    public String checkAnnotationScanOpenAndGetBasePackage(Element rootElement) {
        List<Element> annotationElements = rootElement.selectNodes("//component-scan");
        if (CollectionUtil.isEmpty(annotationElements)) {
            return null;
        }
        if (annotationElements.size() > 1) {
            throw new RuntimeException("the component-scan element can only be defined once");
        }
        String basePackage = annotationElements.get(0).attributeValue("base-package");
        if (StringUtil.isBlank(basePackage)) {
            throw new RuntimeException("the base-package can not be blank");
        }
        return basePackage;
    }

}
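The property-type lookup in getFieldClass() can be sketched on its own as follows. This is a hypothetical variant rather than a drop-in replacement. It matches the setter name exactly (set plus the capitalized property name) instead of using equalsIgnoreCase, and it requires exactly one parameter. It also uses getMethods(), so public setters inherited from a superclass are found as well. getDeclaredMethods() only sees methods declared on the class itself.

```java
import java.lang.reflect.Method;

// Hypothetical sketch: resolve a property's type from a public field, else from a one-argument setter.
class PropertyTypeDemo {

    static class Order {
        public Integer id;
        private String orderNo;

        public void setOrderNo(String orderNo) { this.orderNo = orderNo; }
    }

    static Class<?> resolveType(Class<?> beanClass, String name) {
        try {
            return beanClass.getField(name).getType(); // public fields only
        } catch (NoSuchFieldException ignored) {
            // no public field: fall through to the setter lookup
        }
        String setterName = "set" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
        for (Method method : beanClass.getMethods()) {
            if (method.getName().equals(setterName) && method.getParameterCount() == 1) {
                return method.getParameterTypes()[0];
            }
        }
        throw new IllegalArgumentException("no public field or setter for property " + name);
    }
}
```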

Here is a sample configuration file for reference:

<beans>

    <bean id="order" class="com.rubin.spring.context.beans.Order">
        <property name="id" value="1"></property>
        <property name="orderNo" value="O00001"></property>
        <property name="address" ref="address"></property>
    </bean>

    <bean id="address" class="com.rubin.spring.context.beans.Address" init-method="init">
        <property name="province" value="北京市"></property>
        <property name="city" value="北京市"></property>
        <property name="region" value="丰台区"></property>
        <property name="address" value="东铁匠营街道666号"></property>
    </bean>

    <component-scan base-package="com.rubin.spring.context.beans"></component-scan>

</beans>
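To experiment without dom4j, you can also read the same bean, property, and component-scan structure with the JDK's built-in DOM parser. The sketch below is hypothetical: `XmlConfigDemo` and its fields are illustrative names. It only records ids, class names, property values or references, and the scan package. It does no class loading and no validation beyond what the parser itself does.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

// Hypothetical sketch: read <bean>/<property> and <component-scan> with the JDK DOM API.
class XmlConfigDemo {

    // bean id -> class name, in document order
    final Map<String, String> beanClasses = new LinkedHashMap<>();
    // "beanId.propertyName" -> "ref:x" or "value:y"
    final Map<String, String> properties = new LinkedHashMap<>();
    String scanBasePackage;

    static XmlConfigDemo parse(String xml) {
        try {
            Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
            XmlConfigDemo result = new XmlConfigDemo();
            NodeList beans = doc.getElementsByTagName("bean"); // like the XPath //bean
            for (int i = 0; i < beans.getLength(); i++) {
                Element bean = (Element) beans.item(i);
                String id = bean.getAttribute("id"); // "" when the attribute is absent
                result.beanClasses.put(id, bean.getAttribute("class"));
                NodeList props = bean.getElementsByTagName("property");
                for (int j = 0; j < props.getLength(); j++) {
                    Element prop = (Element) props.item(j);
                    String ref = prop.getAttribute("ref");
                    result.properties.put(id + "." + prop.getAttribute("name"),
                            ref.isEmpty() ? "value:" + prop.getAttribute("value") : "ref:" + ref);
                }
            }
            NodeList scans = doc.getElementsByTagName("component-scan");
            if (scans.getLength() > 0) {
                result.scanBasePackage = ((Element) scans.item(0)).getAttribute("base-package");
            }
            return result;
        } catch (ParserConfigurationException | SAXException | IOException e) {
            throw new IllegalStateException("cannot parse bean configuration", e);
        }
    }
}
```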

Annotation-Scanning Bean Definition Loader

This loader scans for beans annotated with @Component, or with annotations composed from it such as @Service, and parses them into bean definitions:

/**
 * Annotation-based bean definition loader
 */
public class AnnotationBeanDefinitionLoader implements BeanDefinitionLoader {

    private String annotationScanBasePackage;

    private Map<String, BeanDefinition> beanDefinitionMap;

    public AnnotationBeanDefinitionLoader(String annotationScanBasePackage) {
        this.annotationScanBasePackage = annotationScanBasePackage;
        this.beanDefinitionMap = new ConcurrentHashMap<>();
    }

    @Override
    public Map<String, BeanDefinition> load() {
        Set<Class> beanClasses = ReflectUtil.getAllObjectClassWithAnnotationOrSubAnnotation(annotationScanBasePackage, Component.class);
        if (!CollectionUtil.isEmpty(beanClasses)) {
            beanClasses.stream().forEach(beanClass -> parseBeanClass(beanClass));
        }
        return beanDefinitionMap;
    }

    /**
     * Parse a bean class into a BeanDefinition
     *
     * @param beanClass
     */
    private void parseBeanClass(Class beanClass) {
        String beanName = getBeanName(beanClass);
        BeanDefinition beanDefinition = new BeanDefinition();
        beanDefinition.setBeanName(beanName);
        beanDefinition.setBeanClass(beanClass);
        beanDefinition.setInterfaces(ReflectUtil.getInterfaces(beanClass));
        beanDefinition.setInitMethodName(StringUtil.EMPTY);
        beanDefinition.setFieldDefinitions(new ArrayList<>());
        Field[] declaredFields = beanClass.getDeclaredFields();
        for (int i = 0; i < declaredFields.length; i++) {
            parseBeanField(beanDefinition, declaredFields[i]);
        }
        beanDefinitionMap.put(beanName, beanDefinition);
    }

    /**
     * Parse a field that needs injection (@Autowired or @Value)
     *
     * @param beanDefinition
     * @param declaredField
     */
    private void parseBeanField(BeanDefinition beanDefinition, Field declaredField) {
        if (!declaredField.isAnnotationPresent(Autowired.class) && !declaredField.isAnnotationPresent(Value.class)) {
            return;
        }
        FieldDefinition fieldDefinition = new FieldDefinition();
        fieldDefinition.setFieldName(declaredField.getName());
        fieldDefinition.setFieldClass(declaredField.getType());
        if (declaredField.isAnnotationPresent(Autowired.class)) {
            fieldDefinition.setAutowired(true);
            fieldDefinition.setValue(StringUtil.EMPTY);
        } else {
            fieldDefinition.setAutowired(false);
            Value valueAnnotation = declaredField.getAnnotation(Value.class);
            String value = valueAnnotation.value();
            if (StringUtil.isBlank(value)) {
                throw new RuntimeException("the @Value annotation can not be blank");
            }
            fieldDefinition.setValue(value);
        }
        fieldDefinition.setInjectType(InjectType.FIELD);
        fieldDefinition.setRef(StringUtil.EMPTY);
        beanDefinition.getFieldDefinitions().add(fieldDefinition);
    }

    /**
     * Resolve the bean name
     *
     * @param beanClass
     * @return
     */
    private String getBeanName(Class<?> beanClass) {
        // Annotated directly with @Component: use its value, or fall back to the default name
        if (beanClass.isAnnotationPresent(Component.class)) {
            Component component = beanClass.getAnnotation(Component.class);
            if (StringUtil.isNotBlank(component.value())) {
                return component.value();
            }
            return BeanUtil.getDefaultBeanName(beanClass.getSimpleName());
        }
        // Composed annotation such as @Service: read its value() attribute reflectively
        Annotation subAnnotation = ReflectUtil.getSubAnnotation(beanClass, Component.class);
        if (subAnnotation != null) {
            try {
                Method valueMethod = subAnnotation.annotationType().getMethod("value");
                String beanName = String.valueOf(valueMethod.invoke(subAnnotation));
                if (StringUtil.isNotBlank(beanName)) {
                    return beanName;
                }
            } catch (NoSuchMethodException e) {
                // the composed annotation declares no value() attribute; use the default name
            } catch (IllegalAccessException | InvocationTargetException e) {
                e.printStackTrace();
            }
        }
        return BeanUtil.getDefaultBeanName(beanClass.getSimpleName());
    }

}
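The annotation loader stores @Value text as a String (for example "666" for B.number below). The factory therefore has to convert it to the field's type before injecting it. The factory's conversion code is only partially visible in this section, so the sketch below is only an assumption of how such a conversion might look. `ValueConversionDemo`, `convert`, and `injectValue` are hypothetical, and only String, int/Integer, long/Long, boolean/Boolean, and double/Double are handled.

```java
import java.lang.reflect.Field;

// Hypothetical sketch: convert an @Value string to the field's type and inject it.
class ValueConversionDemo {

    static class B {
        Integer number;
        boolean enabled;
    }

    static Object convert(String raw, Class<?> type) {
        if (type == String.class) return raw;
        if (type == Integer.class || type == int.class) return Integer.valueOf(raw);
        if (type == Long.class || type == long.class) return Long.valueOf(raw);
        if (type == Boolean.class || type == boolean.class) return Boolean.valueOf(raw);
        if (type == Double.class || type == double.class) return Double.valueOf(raw);
        throw new IllegalArgumentException("unsupported @Value target type: " + type.getName());
    }

    static void injectValue(Object bean, String fieldName, String raw) {
        try {
            Field field = bean.getClass().getDeclaredField(fieldName);
            // setAccessible(true) lets this also work for private fields,
            // which is how @Value fields are usually declared
            field.setAccessible(true);
            field.set(bean, convert(raw, field.getType())); // boxed values unbox for primitive fields
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new IllegalStateException("cannot inject " + fieldName, e);
        }
    }
}
```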

At this point, our simple IoC container is complete. Let's test it.

Testing

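We define four entities that cover circular dependencies, annotation-based beans, and plain XML-configured beans, so we can exercise each way the container loads beans: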

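// @Getter/@Setter instead of @Data: Lombok's generated equals/hashCode would recurse through the A <-> B cycle
@Getter
@Setter
@Component("a")
public class A implements Serializable {

    private static final long serialVersionUID = -6065196434455688155L;

    @Autowired
    private B b;

    @Value("aName")
    private String name;

    // print only the dependency's type; printing b itself would recurse A -> B -> A forever
    @Override
    public String toString() {
        return "A{" +
                "b=" + (b == null ? "null" : b.getClass().getSimpleName()) +
                ", name=" + name +
                '}';
    }

}

@Getter
@Setter
@Service("b")
public class B implements Serializable {

    private static final long serialVersionUID = 4771041380171984169L;

    @Autowired
    private A a;

    @Value("666")
    private Integer number;

    // print only the dependency's type; printing a itself would recurse B -> A -> B forever
    @Override
    public String toString() {
        return "B{" +
                "a=" + (a == null ? "null" : a.getClass().getSimpleName()) +
                ", number=" + number +
                '}';
    }

}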
@Data
public class Address implements Serializable {

    private static final long serialVersionUID = 7707632678157589784L;

    private String province;

    private String city;

    private String region;

    private String address;

    public void init() {
        System.out.println("执行Address的init-method");
    }

}
@Data
public class Order implements Serializable {

    private static final long serialVersionUID = 8046447974405633199L;

    private Integer id;

    private String orderNo;

    private Address address;

}

Then define the configuration file:

<beans>

    <bean id="order" class="com.rubin.spring.context.beans.Order">
        <property name="id" value="1"></property>
        <property name="orderNo" value="O00001"></property>
        <property name="address" ref="address"></property>
    </bean>

    <bean id="address" class="com.rubin.spring.context.beans.Address" init-method="init">
        <property name="province" value="北京市"></property>
        <property name="city" value="北京市"></property>
        <property name="region" value="丰台区"></property>
        <property name="address" value="东铁匠营街道666号"></property>
    </bean>

    <component-scan base-package="com.rubin.spring.context.beans"></component-scan>

</beans>

Finally, write the test code:

/**
 * Container tests
 * Created by rubin on 4/3/21.
 */
public class ApplicationContextTest {

    /**
     * Get beans by name
     */
    @Test
    public void getBeanByNameTest() {
        ApplicationContext applicationContext = new ClassPathXmlApplicationContext("beans.xml");
        Order order = (Order) applicationContext.getBean("order");
        Address address = (Address) applicationContext.getBean("address");
        A a = (A) applicationContext.getBean("a");
        B b = (B) applicationContext.getBean("b");
        Assert.assertNotNull(order);
        Assert.assertNotNull(address);
        Assert.assertNotNull(a);
        Assert.assertNotNull(b);
    }

    /**
     * Get beans by type
     */
    @Test
    public void getBeanByTypeTest() {
        ApplicationContext applicationContext = new ClassPathXmlApplicationContext("beans.xml");
        Order order = (Order) applicationContext.getBean(Order.class);
        Address address = (Address) applicationContext.getBean(Address.class);
        A a = (A) applicationContext.getBean(A.class);
        B b = (B) applicationContext.getBean(B.class);
        Assert.assertNotNull(order);
        Assert.assertNotNull(address);
        Assert.assertNotNull(a);
        Assert.assertNotNull(b);
    }

}

The tests pass, which completes our simple IoC container.

Attachments

The full source code is available for reference: my-spring-demo

This work is licensed under the Creative Commons Attribution 4.0 International License.
Tags: Spring
Last updated: June 9, 2022

RubinChu

A happy little joker~~~


COPYRIGHT © 2021 rubinchu.com. ALL RIGHTS RESERVED.
