diff --git a/go/gmssl/sm3/sm3.go b/go/gmssl/sm3/sm3.go new file mode 100644 index 00000000..3f43bf37 --- /dev/null +++ b/go/gmssl/sm3/sm3.go @@ -0,0 +1,57 @@ +package sm3 + +import ( + "gmssl" + "hash" +) + +type digest struct { + ctx *gmssl.DigestContext +} + +func New() hash.Hash { + d := new(digest) + ctx, err := gmssl.NewDigestContext("SM3", nil) + if err != nil { + return nil + } + d.ctx = ctx + return d +} + +func (d *digest) BlockSize() int { + ret, err := gmssl.GetDigestBlockSize("SM3") + if err != nil { + return 0 + } + return ret +} + +func (d *digest) Size() int { + ret, err := gmssl.GetDigestLength("SM3") + if err != nil { + return 0 + } + return ret +} + +func (d *digest) Reset() { + err := d.ctx.Reset() + if err != nil { + // do something? + } +} + +func (d *digest) Write(p []byte) (int, error) { + err := d.ctx.Update(p) + return len(p), err +} + +func (d *digest) Sum(in []byte) []byte { + d.ctx.Update(in) + ret, err := d.ctx.Final() + if err != nil { + return nil + } + return ret +} diff --git a/go/gmssltest/gmssltest.go b/go/gmssltest/gmssltest.go index 886455de..3f4e2581 100644 --- a/go/gmssltest/gmssltest.go +++ b/go/gmssltest/gmssltest.go @@ -2,6 +2,7 @@ package main import ( "gmssl" + "gmssl/sm3" "fmt" ) @@ -29,37 +30,41 @@ func main() { fmt.Println(""); /* sm3 digest */ - sm3, err := gmssl.NewDigestContext("SM3", nil) + sm3ctx, err := gmssl.NewDigestContext("SM3", nil) if err != nil { fmt.Println(err) return } - if err := sm3.Update([]byte("a")); err != nil { + if err := sm3ctx.Update([]byte("a")); err != nil { fmt.Println(err) return } - if err := sm3.Update([]byte("bc")); err != nil { + if err := sm3ctx.Update([]byte("bc")); err != nil { fmt.Println(err) return } - sm3digest, err := sm3.Final() + sm3digest, err := sm3ctx.Final() if err != nil { fmt.Println(err) return } fmt.Printf("sm3(\"abc\") = %x\n", sm3digest) + sm3hash := sm3.New() + sm3hash.Write([]byte("abc")) + fmt.Printf("sm3(\"abc\") = %x\n", sm3hash.Sum(nil)) + /* hmac-sm3 */ hmac_sm3, err := gmssl.NewHMACContext("SM3", nil, []byte("this is the key")) if err != nil { fmt.Println(err) return } - if err := hmac_sm3.Update([]byte("hello")); err != nil { + if err := hmac_sm3.Update([]byte("ab")); err != nil { fmt.Println(err) return } - if err := hmac_sm3.Update([]byte("world")); err != nil { + if err := hmac_sm3.Update([]byte("c")); err != nil { fmt.Println(err) return } @@ -68,7 +73,7 @@ func main() { fmt.Println(err) return } - fmt.Printf("hmac-sm3() = %x\n", mactag) + fmt.Printf("hmac-sm3(\"abc\") = %x\n", mactag) /* generate random key */ keylen, err := gmssl.GetCipherKeyLength("SMS4") @@ -113,7 +118,6 @@ func main() { ciphertext := make([]byte, 0, len(ciphertext1) + len(ciphertext2)) ciphertext = append(ciphertext, ciphertext1...) ciphertext = append(ciphertext, ciphertext2...) - fmt.Printf("sms4(\"hello\") = %x\n", ciphertext) /* decrypt */ decryptor, err := gmssl.NewCipherContext("SMS4", nil, key, iv, false) @@ -134,7 +138,8 @@ func main() { plaintext := make([]byte, 0, len(plaintext1) + len(plaintext2)) plaintext = append(plaintext, plaintext1...) plaintext = append(plaintext, plaintext2...) - fmt.Println(string(plaintext)) + + fmt.Printf("sms4(\"%s\") = %x\n", plaintext, ciphertext) }