#!/usr/bin/env python3
"""
SHA-224 hash function implementation in Python
Based on the Secure Hash Algorithm 2 as defined in FIPS 180-4

Author: SHA224.com
License: MIT
"""

import struct
import binascii
from typing import Union, List, Tuple, Optional


class SHA224:
    """
    SHA-224 hash function implementation.

    This class implements the SHA-224 cryptographic hash function as specified
    in FIPS 180-4. It provides methods for incremental hashing (``update``
    followed by ``digest``/``hexdigest``) as well as one-shot hashing via the
    ``hash`` classmethod.

    ``digest``/``hexdigest`` work on a copy of the internal state, so more data
    may be fed with ``update`` afterwards; each digest reflects all data fed so
    far.
    """

    def __init__(self):
        """Initialize SHA-224 hash state."""
        # SHA-224 initial hash values (FIPS 180-4, section 5.3.2).
        # These values are the second 32 bits of the fractional parts of the
        # square roots of the 9th through 16th prime numbers.
        self.h = [
            0xc1059ed8, 0x367cd507, 0x3070dd17, 0xf70e5939,
            0xffc00b31, 0x68581511, 0x64f98fa7, 0xbefa4fa4
        ]

        # SHA-224/SHA-256 round constants (FIPS 180-4, section 4.2.2).
        # First 32 bits of the fractional parts of the cube roots of the
        # first 64 prime numbers.
        self.k = [
            0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
            0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
            0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
            0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
            0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
            0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
            0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
            0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
            0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
            0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
            0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
            0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
            0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
            0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
            0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
            0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
        ]

        # Bytes fed via update() that do not yet fill a complete 64-byte
        # block (fewer than 64 bytes between update() calls).
        self.buffer = bytearray()
        # Total message length fed so far, in bits (encoded into the padding).
        self.counter = 0

    def update(self, data: Union[str, bytes, bytearray]) -> 'SHA224':
        """
        Update the hash with more data.

        Args:
            data: The data to hash. A ``str`` is encoded as UTF-8 first;
                bytes-like objects are used as-is.

        Returns:
            self: For method chaining.
        """
        # Convert string to bytes if necessary
        if isinstance(data, str):
            data = data.encode('utf-8')

        # Update counter (message length in bits)
        self.counter += len(data) * 8

        buf = self.buffer
        buf.extend(data)

        # Compress every complete block by offset, then drop all of them from
        # the front of the buffer with a single deletion. Re-slicing the
        # buffer after each block would copy the remainder every time, making
        # one large update() quadratic in its length.
        full = len(buf) - len(buf) % 64
        for offset in range(0, full, 64):
            self._process_block(buf[offset:offset + 64])
        if full:
            del buf[:full]

        return self

    def digest(self) -> bytes:
        """
        Finalize the hash and return the digest as bytes.

        The internal state is not modified, so hashing may continue with
        further ``update`` calls.

        Returns:
            bytes: The SHA-224 digest (28 bytes).

        Raises:
            struct.error: If the total message length does not fit in the
                64-bit length field (2**64 bits or more).
        """
        # Work on a copy so the running state stays usable after digest().
        h_copy = self.h.copy()

        # Padding: a single '1' bit (0x80), then zero bytes until the length
        # is 56 mod 64, then the message bit length as a 64-bit big-endian
        # integer. (len + 1 + zeros) % 64 == 56  =>  zeros = (55 - len) % 64.
        zero_count = (55 - len(self.buffer)) % 64
        tail = (bytes(self.buffer) + b'\x80' + bytes(zero_count)
                + struct.pack('>Q', self.counter))

        # Process the final block(s): one, or two if the length didn't fit.
        for offset in range(0, len(tail), 64):
            self._process_block(tail[offset:offset + 64], h_copy)

        # SHA-224 truncates to 224 bits: output H0..H6 and drop H7.
        return struct.pack('>7I', *h_copy[:7])

    def hexdigest(self) -> str:
        """
        Finalize the hash and return the digest as a hex string.

        Returns:
            str: The SHA-224 digest as a 56-character lowercase hex string.
        """
        return binascii.hexlify(self.digest()).decode('ascii')

    def _process_block(self, block: bytes, h: Optional[List[int]] = None) -> None:
        """
        Process a single 64-byte block, updating the hash state in place.

        Args:
            block: Block to process; only its first 64 bytes are read, and a
                shorter block raises ``struct.error``.
            h: Hash state (list of 8 words) to update. If None, uses the
                instance's state.
        """
        if h is None:
            h = self.h

        rotr = self._right_rotate
        k = self.k

        # Message schedule: the block's 16 big-endian 32-bit words, extended
        # to 64 words.
        w = list(struct.unpack_from('>16I', block))
        for i in range(16, 64):
            x = w[i - 15]
            y = w[i - 2]
            s0 = rotr(x, 7) ^ rotr(x, 18) ^ (x >> 3)
            s1 = rotr(y, 17) ^ rotr(y, 19) ^ (y >> 10)
            w.append((w[i - 16] + s0 + w[i - 7] + s1) & 0xffffffff)

        # Initialize working variables
        a, b, c, d, e, f, g, h_ = h

        # Compression function main loop (64 rounds)
        for i in range(64):
            S1 = rotr(e, 6) ^ rotr(e, 11) ^ rotr(e, 25)
            # ~e is negative in Python, but & with a non-negative g keeps
            # the result within 32 bits.
            ch = (e & f) ^ ((~e) & g)
            temp1 = (h_ + S1 + ch + k[i] + w[i]) & 0xffffffff

            S0 = rotr(a, 2) ^ rotr(a, 13) ^ rotr(a, 22)
            maj = (a & b) ^ (a & c) ^ (b & c)
            temp2 = (S0 + maj) & 0xffffffff

            h_ = g
            g = f
            f = e
            e = (d + temp1) & 0xffffffff
            d = c
            c = b
            b = a
            a = (temp1 + temp2) & 0xffffffff

        # Add the compressed chunk into the hash state (mod 2**32), in place.
        for idx, value in enumerate((a, b, c, d, e, f, g, h_)):
            h[idx] = (h[idx] + value) & 0xffffffff

    @staticmethod
    def _right_rotate(value: int, shift: int) -> int:
        """
        Right rotate a 32-bit integer by shift bits.

        Args:
            value: The value to rotate.
            shift: The number of bits to rotate by.

        Returns:
            int: The rotated value.
        """
        return ((value >> shift) | (value << (32 - shift))) & 0xffffffff

    @classmethod
    def hash(cls, data: Union[str, bytes, bytearray]) -> str:
        """
        Compute the SHA-224 hash of data in one step.

        Args:
            data: The data to hash. A ``str`` is encoded as UTF-8 first.

        Returns:
            str: The SHA-224 digest as a lowercase hex string.
        """
        return cls().update(data).hexdigest()

    @classmethod
    def verify(cls, data: Union[str, bytes, bytearray], hash_value: str) -> bool:
        """
        Verify if the SHA-224 hash of data matches the expected hash.

        Args:
            data: The data to verify.
            hash_value: The expected SHA-224 hash (hex string, any case).

        Returns:
            bool: True if the hash matches, False otherwise.
        """
        # hexdigest() already yields lowercase; only the expected value needs
        # case normalization.
        computed_hash = cls.hash(data)
        expected_hash = hash_value.lower()

        # The length of a hex digest is not secret, so an early exit is fine.
        if len(computed_hash) != len(expected_hash):
            return False

        # Accumulate differences over every character rather than stopping at
        # the first mismatch, so timing does not reveal the matching prefix.
        result = 0
        for a, b in zip(computed_hash, expected_hash):
            result |= ord(a) ^ ord(b)

        return result == 0


# Simple self-test and example usage.
# Test-vector checks raise SystemExit explicitly instead of using `assert`,
# because `python -O` strips assert statements and would report every vector
# as passed without checking it.
if __name__ == "__main__":
    # FIPS 180-4 / NIST example test vectors for SHA-224
    test_vectors = [
        (b"", "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f"),
        (b"abc", "23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"),
        (b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
         "75388b16512776cc5dba5da1fd890150b0c6455cb4f58b1952522525"),
        (b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu",
         "c97ca9a559850ce97a04a96def6d99a9e0e0e2ab14e6b8df265fc0b3"),
    ]

    # Run tests; a mismatch exits with status 1 and the message on stderr.
    for i, (data, expected) in enumerate(test_vectors):
        result = SHA224.hash(data)
        if result != expected:
            raise SystemExit(f"Test {i+1} failed: expected {expected}, got {result}")
        print(f"Test {i+1} passed!")

    # Example usage
    message = "Hello, World!"
    hash_value = SHA224.hash(message)
    print(f"SHA-224 hash of '{message}': {hash_value}")

    # Verify example
    is_valid = SHA224.verify(message, hash_value)
    print(f"Verification result: {is_valid}")