Java实现国密Sm3算法

虽然国密算法SM3的官方文档只有短短四五页,但实现起来涉及的细节还是挺多的。特此,Milo将开发心得,以及源码和注解做个分享,希望读者能够轻松理解SM3。 + 预备知识 - Java基本数据类型大小 byte:1个字节;short:2个字节;char:2个字节;int:4个字节;long:8个字节等...

虽然国密算法SM3的官方文档只有短短四五页,但实现起来涉及的细节还是挺多的。特此,Milo将开发心得,以及源码和注解做个分享,希望读者能够轻松理解SM3。

  • 预备知识

    • Java基本数据类型大小

      byte:1个字节;short:2个字节;char:2个字节;int:4个字节;long:8个字节等等。

      十六进制:一个数代表4个二进制位。所以需要拿捏好数据长短。

    • Ascii码

      有区分十进制、八进制、十六进制。

      官方文档样例中使用的都是十六进制,“abc”=“616263”,而不是十进制Ascii码“979899”。计算和操作过程中,要留意数据类型转换的问题。

    • 左移和循环左移

      官方文档中提及的移位计算均为循环左移。即移出的高位放到该数的低位。

    • 左补0

      在转换数据类型时,除了关注溢出问题、Ascii码问题,还要关注转换结果的长度问题。SM3算法中,若二进制长度不足需要补全,否则拼接结果时会出现意想不到的错误。

  • 辅助方法

    • 辅助变量

      辅助变量主要定义了参数的初始值,循环边界变量,字长,位长,执行轮次等信息。由于官方文档中使用的单位是字,且字的长度是32位,因此,Milo选用int基本数据类型作为存储单位。值得注意的是数组T的初始化,Milo选择在构造函数中进行,读者也可以选择在静态代码块中执行初始化命令。

      private static final int[] IV = {0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e};
      private static final Integer T_LEN = 64;
      private static final int[] T = new int[T_LEN];
      // 也可以使用静态代码块,初始化顺序:静态代码块->构造函数
      public Sm3() {
        // 初始化T数组
        for (int j = 0; j < T_LEN; ++j) {
          T[j] = j < 16 ? Integer.parseInt("79cc4519", 16) : Integer.parseInt("7a879d8a", 16);
        }
      }
      private static final int BIT_LEN = 8;
      private static final int PADDING_LEN = 64;
      private static final int GROUP_NUM = 132;
      private static final int FIRST_GROUP_END = 15;
      private static final int SECOND_GROUP_START = 16;
      private static final int SECOND_GROUP_END = 67;
      private static final int THIRD_GROUP_START = 68;
      private static final int THIRD_GROUP_END = 131;
      private static final int BIT_PER_WORD = 32;
      private static final int COMPRESSION_TIMES = 64;
    • 辅助函数

      辅助函数除了实现官方文档中指定的函数外,Milo还实现了字符串左补零和获取消息比特串的函数。后续主要的返回结果都是字符串类型。

      计算过程中,主要注意的点有两个方面:一是左移操作,不能使用<<进行左移,而应该使用Integer.rotateLeft()方法进行循环左移;二是十六进制转整型时,要时刻考虑着数值溢出的问题,具体位置见代码注释。

      /**
      * 第一个布尔函数FFj。
      * 官方文档:式中X,Y,Z 为字。但操作是32比特逻辑运算。
      * 注意:此处没做j的校验,因为调用时限制j的范围。
      *
      * @return 官方文档中,1字 = 32比特,做32bit逻辑运算结果应该是4字节,所以选择int基本数据类型。
      */
      private int ff(int x, int y, int z, int j) {
        return j < 16 ? x ^ y ^ z : (x & y) | (x & z) | (y & z);
      }
      
      /**
      * 第二个布尔函数GGj。
      *
      * @return 同上FFj。
      */
      private int gg(int x, int y, int z, int j) {
        return j < 16 ? x ^ y ^ z : (x & y) | (~x & z);
      }
      
      /**
      * 压缩函数中的置换函数。
      *
      * @param x x为字。
      * @return 做32位逻辑运算,选择返回int。
      */
      private int p0(int x) {
        return x ^ Integer.rotateLeft(x, 9) ^ Integer.rotateLeft(x, 17);
      }
      
      /**
      * 消息扩展中的置换函数。
      *
      * @param x x为字。
      * @return 同上p0。
      */
      private int p1(int x) {
        return x ^ Integer.rotateLeft(x, 15) ^ Integer.rotateLeft(x, 23);
      }
      
      /**
      * 左补0
      *
      * @param src    初始值
      * @param length 最终结果长度
      */
      private String leftSupplementZero(StringBuilder src, int length) {
        while (src.length() < length) {
          src.insert(0, "0");
        }
        return src.toString();
      }
      
      /**
      * 获取消息的比特串
      * String明文转Hex字符串
      *
      * @param message 原始消息字符串
      * @return 比特串
      */
      private String getBinaryMessage(String message) {
        StringBuilder stringBuilder = new StringBuilder();
        for (char ch : message.toCharArray()) {
          // 一个个字符处理,不会超过int范围
          StringBuilder convert = new StringBuilder(Integer.toBinaryString(Integer.parseInt(Integer.toHexString(ch), 16)));
          // 左补0,补齐8位
          leftSupplementZero(convert, BIT_LEN);
          stringBuilder.append(convert);
        }
        return stringBuilder.toString();
      }
  • 主要过程

    • 填充

      填充时关注的点有两个:首先,使用的是二进制的消息。其次,在求k时要考虑长度l可能小于448,也可能大于448,其最长不超过2^64。

      	/**
        * 假设消息m的长度为l比特。
        * 首先将比特“1”添加到消息的末尾
        * 再添加k个“0”,k是满 足l + 1 + k ≡ 448mod512 的最小的非负整数。
        * 然后再添加一个64位比特串,该比特串是长度l的二进制表示。
        *
        * @param message 消息m
        * @return 填充后的消息m′的比特长度为512的倍数。
        */
      private Map<String, Object> filling(String message) {
        Map<String, Object> result = new LinkedHashMap<>();
        // 获取二进制消息
        String binaryMessage = getBinaryMessage(message);
        // 获取消息m的长度l
        int l = binaryMessage.length();
        // 再添加1
        StringBuilder filledResult = new StringBuilder(binaryMessage + "1");
        // 添加k个0
        int k = 448 - (l + 1) % 512;
        result.put("k", k);
        StringBuilder binaryL = new StringBuilder(Integer.toBinaryString(l));
        while (k > 0) {
          filledResult.append("0");
          --k;
        }
        // 添加64位比特串
        // 左补0
        leftSupplementZero(binaryL, PADDING_LEN);
        filledResult.append(binaryL);
        result.put("l", l);
        result.put("filledResult", filledResult);
        return result;
      }	
    • 迭代压缩

      迭代压缩时,此处由于压缩函数compression传递的参数时数组,所以使用了二维数组b和v。简单理解是,对于b,填充后的消息会被分成n组,每组(段)消息又会被切分成132个字;对于v,压缩函数中,每次都会循环计算63次,每次都会得到一个v,这个中间值v,会在下一次循环中被拆分成ABCDEFGH,共8段。v(0)是预先设置好的,同样是8段。

      此处会出现数值溢出的问题,Milo查阅源码后采用Integer.parseUnsignedInt()方法。它在数值过大时,自动采用long类型来进行运算。

      /**
      * 迭代压缩
      *
      * @param fillingResult 填充后的结果,包括原始消息长度l,填充0的个数k,消息后的消息m'
      * @return vn 最终的杂凑值
      */
      private String iterativeCompression(Map<String, Object> fillingResult) {
        // 获取l,k计算n
        int l = (int) fillingResult.get("l"), k = (int) fillingResult.get("k"), n = (l + k + 65) / 512;
        // 存储分组
        String[] mGroups = new String[n];
        // 存储扩展消息
        String[][] b = new String[n][132];
        // 获取填充后的消息,进行分组
        String filledResult = ((StringBuilder) fillingResult.get("filledResult")).toString();
        for (int i = 0; i < n; ++i) {
          // 分组,每组512比特。起始位置:0,512,1024;结束位置:511,1023,1535
          mGroups[i] = filledResult.substring(i * 512, i * 512 + 512);
          b[i] = messageExpand(mGroups[i]);
        }
        // 迭代数组。结果取v(n)的十六进制ascii码
        int[][] v = new int[n + 1][];
        // 初始化v(0)
        v[0] = IV;
        for (int j = 0; j < n; ++j) {
          int[] bj = new int[132];
          for (int x = 0; x < bj.length; ++x) {
            // 处理越界问题
            bj[x] = Integer.parseUnsignedInt(b[j][x], 2);
          }
          v[j + 1] = compression(v[j], bj);
        }
        StringBuilder result = new StringBuilder();
        for (int y = 0; y < v[n].length; ++y) {
          //转十六进制ascii码
          result.append(Integer.toHexString(v[n][y]));
      }
        return result.toString();
      }
  • 完整代码及注释(已通过官方文档两个样例测试)

/**
* SM3杂凑算法。
* 使用的是16进制ascii表。
* 统一使用byte数组存储。
* 概述:对长度为l(l < 264) 比特的消息m,SM3杂凑算法经过填充、迭代压缩,生成杂凑值,杂凑值长度为256比特。
*/
public static class Sm3 {
  // 也可以使用静态代码块,初始化顺序:静态代码块->构造函数
  public Sm3() {
    // 初始化T数组
    for (int j = 0; j < T_LEN; ++j) {
      T[j] = j < 16 ? Integer.parseInt("79cc4519", 16) : Integer.parseInt("7a879d8a", 16);
    }
  }
  
  /**
  * 初始值iv,长度:4 * 8 * 8 = 256 bit。
  * 1 byte = 8 bit。
  * 16进制:2个数 = 8bit。
  * 官方文档:1字 = 32比特,所以此处使用int类型。
  * 如果使用其他基本数据类型,要注意数据类型大小,并拆分iv。
  */
  private static final int[] IV = {0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e};
  
  /**
  * 常量Tj,长度64。
  * 内容长度:4 * 8 = 32 bit,用int足够。
  * 在构造方法中初始化。
  */
  private static final Integer T_LEN = 64;
  private static final int[] T = new int[T_LEN];
  /**
  * 获得消息比特串时的位长
  */
  private static final int BIT_LEN = 8;
  /**
  * 填充消息时,末尾添加的0比特串长度
  */
  private static final int PADDING_LEN = 64;
  /**
  * 消息扩展时,拆分成132个字
  */
  private static final int GROUP_NUM = 132;
  /**
  * 消息扩展第一组:0~15
  */
  private static final int FIRST_GROUP_END = 15;
  /**
  * 消息扩展第二组:16~67
  */
  private static final int SECOND_GROUP_START = 16;
  private static final int SECOND_GROUP_END = 67;
  /**
  * 消息扩展第三组:68~131
  */
  private static final int THIRD_GROUP_START = 68;
  private static final int THIRD_GROUP_END = 131;
  /**
  * 每个字的长度:32比特
  */
  private static final int BIT_PER_WORD = 32;
  /**
  * 压缩函数执行轮次
  */
  private static final int COMPRESSION_TIMES = 64;
  
  /**
  * 第一个布尔函数FFj。
  * 官方文档:式中X,Y,Z 为字。但操作是32比特逻辑运算。
  * 注意:此处没做j的校验,因为调用时限制j的范围。
  *
  * @return 官方文档中,1字 = 32比特,做32bit逻辑运算结果应该是4字节,所以选择int基本数据类型。
  */
  private int ff(int x, int y, int z, int j) {
    return j < 16 ? x ^ y ^ z : (x & y) | (x & z) | (y & z);
  }
  
  /**
  * 第二个布尔函数GGj。
  *
  * @return 同上FFj。
  */
  private int gg(int x, int y, int z, int j) {
    return j < 16 ? x ^ y ^ z : (x & y) | (~x & z);
  }
  
  /**
  * 压缩函数中的置换函数。
  *
  * @param x x为字。
  * @return 做32位逻辑运算,选择返回int。
  */
  private int p0(int x) {
    return x ^ Integer.rotateLeft(x, 9) ^ Integer.rotateLeft(x, 17);
  }
  
  /**
  * 消息扩展中的置换函数。
  *
  * @param x x为字。
  * @return 同上p0。
  */
  private int p1(int x) {
    return x ^ Integer.rotateLeft(x, 15) ^ Integer.rotateLeft(x, 23);
  }
  
  /**
  * 左补0
  *
  * @param src    初始值
  * @param length 最终结果长度
  */
  private String leftSupplementZero(StringBuilder src, int length) {
    while (src.length() < length) {
      src.insert(0, "0");
    }
    return src.toString();
  }
  
  /**
  * 获取消息的比特串
  * String明文转Hex字符串
  *
  * @param message 原始消息字符串
  * @return 比特串
  */
  private String getBinaryMessage(String message) {
    StringBuilder stringBuilder = new StringBuilder();
    for (char ch : message.toCharArray()) {
      // 一个个字符处理,不会超过int范围
      StringBuilder convert = new StringBuilder(Integer.toBinaryString(Integer.parseInt(Integer.toHexString(ch), 16)));
      // 左补0,补齐8位
      leftSupplementZero(convert, BIT_LEN);
      stringBuilder.append(convert);
    }
    return stringBuilder.toString();
  }
  
  /**
  * 假设消息m的长度为l比特。
  * 首先将比特“1”添加到消息的末尾
  * 再添加k 个“0”,k是满 足l + 1 + k ≡ 448mod512 的最小的非负整数。
  * 然后再添加一个64位比特串,该比特串是长度l的二进制表示。
  *
  * @param message 消息m
  * @return 填充后的消息m′的比特长度为512的倍数。
  */
  private Map<String, Object> filling(String message) {
    Map<String, Object> result = new LinkedHashMap<>();
    // 获取二进制消息
    String binaryMessage = getBinaryMessage(message);
    // 获取消息m的长度l
    int l = binaryMessage.length();
    // 再添加1
    StringBuilder filledResult = new StringBuilder(binaryMessage + "1");
    // 添加k个0
    int k = 448 - (l + 1) % 512;
    result.put("k", k);
    StringBuilder binaryL = new StringBuilder(Integer.toBinaryString(l));
    while (k > 0) {
      filledResult.append("0");
      --k;
    }
    // 添加64位比特串
    // 左补0
    leftSupplementZero(binaryL, PADDING_LEN);
    filledResult.append(binaryL);
    result.put("l", l);
    result.put("filledResult", filledResult);
    return result;
  }
  
  /**
  * 消息扩展
  *
  * @param bi 512比特的消息分组Bi
  * @return 132个字W0, W1, ..., W63, W'0,W'1,...,W'63
  */
  private String[] messageExpand(String bi) {
    // 第一组0~15个字:将512比特分成16个字(此处表明:1字 = 32比特)
    int[] w = new int[GROUP_NUM];
    for (int i = 0; i <= FIRST_GROUP_END; ++i) {
      // 开头:0,32,64;结尾:31,63,95
      w[i] = Integer.parseUnsignedInt(bi.substring(i * 32, i * 32 + 32), 2);
    }
    // 第二组16~67个字
    for (int j = SECOND_GROUP_START; j <= SECOND_GROUP_END; ++j) {
      w[j] = p1(w[j - 16] ^ w[j - 9] ^ Integer.rotateLeft(w[j - 3], 15)) ^ Integer.rotateLeft(w[j - 13], 7) ^ w[j - 6];
    }
    // 第三组68~131个字
    for (int k = THIRD_GROUP_START; k <= THIRD_GROUP_END; ++k) {
      w[k] = w[k - 68] ^ w[k - 64];
    }
    // 转二进制输出结果
    String[] result = new String[GROUP_NUM];
    for (int x = 0; x < GROUP_NUM; ++x) {
      result[x] = leftSupplementZero(new StringBuilder(Integer.toBinaryString(w[x])), BIT_PER_WORD);
    }
    return result;
  }
  
  /**
  * 压缩函数V(i+1) = CF(Vi,Bi)
  *
  * @param vi 上一次结果 256比特
  * @param bi 第i组 512比特 132个分组
  * @return v(i + 1) 256比特
  */
  private int[] compression(int[] vi, int[] bi) {
    int a = vi[0], b = vi[1], c = vi[2], d = vi[3], e = vi[4], f = vi[5], g = vi[6], h = vi[7];
    System.out.println("j       A       B       C       D       E       F       G       H");
    System.out.println(" \t" + Integer.toHexString(a) + "\t" + Integer.toHexString(b) + "\t" + Integer.toHexString(c) + "\t" + Integer.toHexString(d) + "\t" + Integer.toHexString(e) + "\t" + Integer.toHexString(f) + "\t" + Integer.toHexString(g) + "\t" + Integer.toHexString(h));
    for (int j = 0; j < COMPRESSION_TIMES; ++j) {
      int ss1 = Integer.rotateLeft(Integer.rotateLeft(a, 12) + e + Integer.rotateLeft(T[j], j), 7),
      ss2 = ss1 ^ Integer.rotateLeft(a, 12),
      tt1 = ff(a, b, c, j) + d + ss2 + bi[68 + j],
      tt2 = gg(e, f, g, j) + h + ss1 + bi[j];
      d = c;
      c = Integer.rotateLeft(b, 9);
      b = a;
      a = tt1;
      h = g;
      g = Integer.rotateLeft(f, 19);
      f = e;
      e = p0(tt2);
      System.out.println(j + "\t" + Integer.toHexString(a) + "\t" + Integer.toHexString(b) + "\t" + Integer.toHexString(c) + "\t" + Integer.toHexString(d) + "\t" + Integer.toHexString(e) + "\t" + Integer.toHexString(f) + "\t" + Integer.toHexString(g) + "\t" + Integer.toHexString(h));
    }
    int[] result = new int[8];
    // 最终结果还需要做一次32位异或操作
    result[0] = a ^ vi[0];
    result[1] = b ^ vi[1];
    result[2] = c ^ vi[2];
    result[3] = d ^ vi[3];
    result[4] = e ^ vi[4];
    result[5] = f ^ vi[5];
    result[6] = g ^ vi[6];
    result[7] = h ^ vi[7];
    return result;
  }
  
  /**
  * 迭代压缩
  *
  * @param fillingResult 填充后的结果,包括原始消息长度l,填充0的个数k,消息后的消息m'
  * @return vn 最终的杂凑值
  */
  private String iterativeCompression(Map<String, Object> fillingResult) {
    // 获取l,k计算n
    int l = (int) fillingResult.get("l"), k = (int) fillingResult.get("k"), n = (l + k + 65) / 512;
    // 存储分组
    String[] mGroups = new String[n];
    // 存储扩展消息
    String[][] b = new String[n][132];
    // 获取填充后的消息,进行分组
    String filledResult = ((StringBuilder) fillingResult.get("filledResult")).toString();
    for (int i = 0; i < n; ++i) {
      // 分组,每组512比特。起始位置:0,512,1024;结束位置:511,1023,1535
      mGroups[i] = filledResult.substring(i * 512, i * 512 + 512);
      b[i] = messageExpand(mGroups[i]);
    }
    // 迭代数组。结果取v(n)的十六进制ascii码
    int[][] v = new int[n + 1][];
    // 初始化v(0)
    v[0] = IV;
    for (int j = 0; j < n; ++j) {
      int[] bj = new int[132];
      for (int x = 0; x < bj.length; ++x) {
        // 处理越界问题
        bj[x] = Integer.parseUnsignedInt(b[j][x], 2);
      }
      v[j + 1] = compression(v[j], bj);
    }
    StringBuilder result = new StringBuilder();
    for (int y = 0; y < v[n].length; ++y) {
      //转十六进制ascii码
      result.append(Integer.toHexString(v[n][y]));
    }
    return result.toString();
  }

  /**
  * 加密算法。
  * 外部调用,不能声明为static
  *
  * @param src 源文
  * @return 密文
  */
  public String cipher(String src) {
    // 填充
    Map<String, Object> fillingResult = filling(src);
    // 迭代压缩
    return iterativeCompression(fillingResult);
  }
}

转载须知

本文欢迎转载,但请务必保留原文链接,谢谢!

商业合作请联系邮箱:choibunbing@gmail.com