import struct
from Cryptodome.Protocol.KDF import PBKDF2
from Cryptodome.Cipher import AES
from Cryptodome.Hash import HMAC
from Cryptodome.Hash.SHA1 import SHA1Hash
from Cryptodome.Util import Counter
from Cryptodome import Random
from .zipfile import (
ZIP_BZIP2,
ZIP_LZMA,
BadZipFile,
BaseZipDecrypter,
ZipFile,
ZipInfo,
ZipExtFile,
)
# Name of the WinZip AES encryption scheme. ``AESZipFile.get_encrypter``
# compares ``AESZipFile.encryption`` against this value.
WZ_AES = 'WZ_AES'
# NOTE(review): presumably the ZIP "compression method" value that marks an
# AES-encrypted entry; it is not referenced in this part of the file --
# confirm against the WinZip AES specification.
WZ_AES_COMPRESS_TYPE = 99
# AES vendor versions. For WZ_AES_V2 the stored CRC is expected to be 0, so
# ``AESZipExtFile.check_integrity`` only checks a CRC that is non-zero.
WZ_AES_V1 = 0x0001
WZ_AES_V2 = 0x0002
# Vendor id written to ``zipinfo.wz_aes_vendor_id`` by the encrypter.
WZ_AES_VENDOR_ID = b'AE'
# NOTE(review): presumably the extra-field header id of the AES record; it is
# not referenced in this part of the file -- confirm.
EXTRA_WZ_AES = 0x9901
# Salt length in bytes, keyed by the AES strength code (``wz_aes_strength``).
WZ_SALT_LENGTHS = {
    1: 8, # 128 bit
    2: 12, # 192 bit
    3: 16, # 256 bit
}
# AES key length in bytes, keyed by the AES strength code
# (``wz_aes_strength``).
WZ_KEY_LENGTHS = {
    1: 16, # 128 bit
    2: 24, # 192 bit
    3: 32, # 256 bit
}
class AESZipDecrypter(BaseZipDecrypter):
    """Decrypt and authenticate a single ``WZ_AES`` encrypted archive member.

    The key material is derived from the password and the salt found at the
    start of the encryption header using PBKDF2 (1000 iterations). It is
    split, in order, into the AES key, the HMAC key and a 2 byte password
    verification value.
    """

    # Number of leading HMAC-SHA1 digest bytes compared by `check_hmac`.
    hmac_size = 10

    def __init__(self, zinfo, pwd, encryption_header):
        """Derive the keys for `zinfo` and verify the password.

        Args:
            zinfo: Entry description; `filename` and `wz_aes_strength` are
                read from it.
            pwd: Password used for the key derivation.
            encryption_header: Salt immediately followed by the 2 byte
                password verification value.

        Raises:
            KeyError: If `zinfo.wz_aes_strength` is not a known strength.
            BadZipFile: If `encryption_header` is too short to hold the salt
                and the password verification value.
            RuntimeError: If the derived verification value does not match
                the one stored in `encryption_header`.
        """
        self.filename = zinfo.filename

        key_length = WZ_KEY_LENGTHS[zinfo.wz_aes_strength]
        salt_length = WZ_SALT_LENGTHS[zinfo.wz_aes_strength]
        pwd_verify_length = 2

        # A short header means the archive is truncated or corrupt. Report
        # that explicitly instead of failing with an opaque struct.error or
        # a misleading "Bad password".
        if len(encryption_header) < salt_length + pwd_verify_length:
            raise BadZipFile(
                "Truncated encryption header for file %r" % zinfo.filename
            )

        salt = bytes(encryption_header[:salt_length])
        pwd_verify = encryption_header[salt_length:]

        dkLen = 2 * key_length + pwd_verify_length
        keymaterial = PBKDF2(pwd, salt, count=1000, dkLen=dkLen)

        # Layout of keymaterial: AES key | HMAC key | password verifier.
        enckey = keymaterial[:key_length]
        encmac_key = keymaterial[key_length:2 * key_length]
        encpwdverify = keymaterial[2 * key_length:]

        if encpwdverify != pwd_verify:
            raise RuntimeError("Bad password for file %r" % zinfo.filename)

        # NOTE: the attribute name `decypter` (sic) is kept unchanged so any
        # existing code that accesses it keeps working.
        self.decypter = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True)
        )
        self.hmac = HMAC.new(encmac_key, digestmod=SHA1Hash())

    @staticmethod
    def encryption_header_length(zinfo):
        """Return the header size for `zinfo`: salt plus 2 verifier bytes."""
        # salt_length + pwd_verify_length
        salt_length = WZ_SALT_LENGTHS[zinfo.wz_aes_strength]
        return salt_length + 2

    def decrypt(self, data):
        """Feed the encrypted `data` to the HMAC and return it decrypted."""
        # The HMAC is updated with the data as read, before decryption.
        self.hmac.update(data)
        return self.decypter.decrypt(data)

    def check_hmac(self, hmac_check):
        """Compare the running HMAC with the stored `hmac_check` bytes.

        Raises:
            BadZipFile: If the first `hmac_size` digest bytes differ from
                `hmac_check`.
        """
        # Imported locally so this block does not depend on a new
        # module-level import.
        from hmac import compare_digest

        expected = self.hmac.digest()[:self.hmac_size]
        # Constant-time comparison avoids leaking how many leading bytes of
        # the authentication code matched.
        if not compare_digest(expected, hmac_check):
            raise BadZipFile("Bad HMAC check for file %r" % self.filename)
class BaseZipEncrypter:
    """Interface that concrete ZIP encrypters are expected to implement.

    `update_zipinfo`, `encrypt` and `encryption_header` raise
    NotImplementedError here; `flush` defaults to producing no bytes.
    """

    def update_zipinfo(self, zipinfo):
        """Record encryption details on `zipinfo` (must be overridden)."""
        message = (
            'BaseZipEncrypter implementations must implement `update_zipinfo`.'
        )
        raise NotImplementedError(message)

    def encrypt(self, data):
        """Return `data` encrypted (must be overridden)."""
        message = 'BaseZipEncrypter implementations must implement `encrypt`.'
        raise NotImplementedError(message)

    def encryption_header(self):
        """Return the encryption header bytes (must be overridden)."""
        message = (
            'BaseZipEncrypter implementations must implement '
            '`encryption_header`.'
        )
        raise NotImplementedError(message)

    def flush(self):
        """Return trailing bytes to emit after the data; none by default."""
        return bytes()
class AESZipEncrypter(BaseZipEncrypter):
    """Encrypt and authenticate a single archive member with ``WZ_AES``.

    A random salt is generated per instance. The key material is derived
    from the password and that salt using PBKDF2 (1000 iterations) and is
    split, in order, into the AES key, the HMAC key and a 2 byte password
    verification value.
    """

    # Number of leading HMAC-SHA1 digest bytes emitted by `flush`.
    hmac_size = 10

    # Maps the key size in bits to the AES strength code used as the key of
    # the module-level WZ_SALT_LENGTHS / WZ_KEY_LENGTHS tables.
    _STRENGTH_BY_NBITS = {
        128: 1,
        192: 2,
        256: 3,
    }

    def __init__(self, pwd, nbits=256, force_wz_aes_version=None):
        """Generate a salt and derive the encryption and HMAC keys.

        Args:
            pwd: Password used for the key derivation; must be non-empty.
            nbits: AES key size in bits: 128, 192 or 256.
            force_wz_aes_version: When not None, `update_zipinfo` stores
                this value as `wz_aes_version` on the zipinfo.

        Raises:
            RuntimeError: If `pwd` is empty or `nbits` is not supported.
        """
        if not pwd:
            raise RuntimeError(
                '%s encryption requires a password.' % WZ_AES
            )

        if nbits not in (128, 192, 256):
            # The 1-tuple keeps the formatting correct even when `nbits`
            # is itself a tuple.
            raise RuntimeError(
                "`nbits` must be one of 128, 192, 256. Got '%s'" % (nbits,)
            )

        self.force_wz_aes_version = force_wz_aes_version

        # Use the same module-level tables as AESZipDecrypter so the two
        # sides cannot drift apart.
        self.aes_strength = self._STRENGTH_BY_NBITS[nbits]
        self.salt_length = WZ_SALT_LENGTHS[self.aes_strength]
        key_length = WZ_KEY_LENGTHS[self.aes_strength]

        self.salt = Random.new().read(self.salt_length)

        pwd_verify_length = 2
        dkLen = 2 * key_length + pwd_verify_length
        keymaterial = PBKDF2(pwd, self.salt, count=1000, dkLen=dkLen)

        # Layout of keymaterial: AES key | HMAC key | password verifier.
        enckey = keymaterial[:key_length]
        encmac_key = keymaterial[key_length:2 * key_length]
        self.encpwdverify = keymaterial[2 * key_length:]

        self.encrypter = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True)
        )
        self.hmac = HMAC.new(encmac_key, digestmod=SHA1Hash())

    def update_zipinfo(self, zipinfo):
        """Store the AES vendor id, strength and forced version on `zipinfo`.

        `wz_aes_version` is only written when a version was forced in the
        constructor.
        """
        zipinfo.wz_aes_vendor_id = WZ_AES_VENDOR_ID
        zipinfo.wz_aes_strength = self.aes_strength
        if self.force_wz_aes_version is not None:
            zipinfo.wz_aes_version = self.force_wz_aes_version

    def encryption_header(self):
        """Return the salt followed by the password verification value."""
        return self.salt + self.encpwdverify

    def encrypt(self, data):
        """Return `data` encrypted, feeding the ciphertext to the HMAC."""
        data = self.encrypter.encrypt(data)
        # The HMAC covers the encrypted bytes, mirroring AESZipDecrypter,
        # which updates its HMAC before decrypting.
        self.hmac.update(data)
        return data

    def flush(self):
        """Return the first `hmac_size` bytes of the HMAC digest."""
        # Slice with `hmac_size` rather than a literal 10 so the emitted
        # length and the digest slice can never disagree.
        return self.hmac.digest()[:self.hmac_size]
class AESZipInfo(ZipInfo):
    """Class with attributes describing each file in the ZIP archive.

    Adds the WinZip AES attributes `wz_aes_version`, `wz_aes_vendor_id` and
    `wz_aes_strength`; all three start as None.
    """

    # __slots__ on subclasses only need to contain the additional slots.
    __slots__ = (
        'wz_aes_version',
        'wz_aes_vendor_id',
        'wz_aes_strength',
        # 'wz_aes_actual_compression_type',
    )

    def __init__(self, *args, **kwargs):
        """Initialise the base ZipInfo and reset the AES attributes."""
        super().__init__(*args, **kwargs)
        self.wz_aes_version = None
        self.wz_aes_vendor_id = None
        self.wz_aes_strength = None
        # self.wz_aes_actual_compression_type = counts[3]

    def encode_central_directory(self, *, crc, compress_type, extra_data,
                                 **kwargs):
        """Encode the central directory record including the AES extra data.

        `crc` and `compress_type` are replaced by the values returned from
        `self.encode_extra`, and the extra bytes it returns are appended to
        `extra_data` before delegating to the base class.

        NOTE(review): `encode_extra` is not defined in the visible part of
        this file; it is assumed to be provided elsewhere -- confirm.
        """
        wz_aes_extra, crc, compress_type = self.encode_extra(
            crc, compress_type)
        return super().encode_central_directory(
            crc=crc,
            compress_type=compress_type,
            extra_data=extra_data + wz_aes_extra,
            **kwargs)
class AESZipExtFile(ZipExtFile):
    """ZipExtFile that can read ``WZ_AES`` encrypted archive members.

    Entries whose zipinfo has `wz_aes_version` set are decrypted with
    AESZipDecrypter; every other entry is handled by the base class.
    """

    def setup_aeszipdecrypter(self):
        """Read the AES encryption header and return the decrypter class.

        The header is stored on `self.encryption_header`, and
        `_orig_compress_left` is reduced by the header length and the
        trailing HMAC length.

        Raises:
            RuntimeError: If no password was supplied.
        """
        if not self._pwd:
            raise RuntimeError(
                'File %r is encrypted with %s encryption and requires a '
                'password.' % (self.name, WZ_AES)
            )
        encryption_header_length = AESZipDecrypter.encryption_header_length(
            self._zinfo)
        self.encryption_header = self._fileobj.read(encryption_header_length)
        # Adjust read size for encrypted files since the start of the file
        # may be used for the encryption/password information.
        self._orig_compress_left -= encryption_header_length
        # Also remove the hmac length from the end of the file.
        self._orig_compress_left -= AESZipDecrypter.hmac_size
        return AESZipDecrypter

    def setup_decrypter(self):
        """Return the decrypter class: AES if flagged, else the base one."""
        if self._zinfo.wz_aes_version is not None:
            return self.setup_aeszipdecrypter()
        return super().setup_decrypter()

    def check_wz_aes(self):
        """Verify the stored HMAC against the one computed while reading.

        Raises:
            BadZipFile: If an LZMA stream yields data beyond what was
                already read, or (via the decrypter) if the HMAC differs.
        """
        if self._zinfo.compress_type == ZIP_LZMA:
            # LZMA may have an end of stream marker or padding. Make sure we
            # read that to get the proper HMAC of the compressed byte stream.
            while self._compress_left > 0:
                data = self._read2(self.MIN_READ_SIZE)
                # but we don't want to find any more data here.
                data = self._decompressor.decompress(data)
                if data:
                    # Use `self.name`, as setup_aeszipdecrypter does, so the
                    # error names the entry the same way everywhere here.
                    raise BadZipFile(
                        "More data found than indicated by uncompressed "
                        "size for '{}'".format(self.name)
                    )

        hmac_check = self._fileobj.read(self._decrypter.hmac_size)
        self._decrypter.check_hmac(hmac_check)

    def check_integrity(self):
        """Run the HMAC check (and CRC where applicable) for the entry.

        Entries that are not AES encrypted are delegated to the base class.
        """
        if self._zinfo.wz_aes_version is not None:
            self.check_wz_aes()
            if self._expected_crc is not None and self._expected_crc != 0:
                # Not part of the spec but still check the CRC if it is
                # supplied when WZ_AES_V2 is specified (no CRC check and CRC
                # should be 0).
                self.check_crc()
            elif self._zinfo.wz_aes_version != WZ_AES_V2:
                # CRC value should be 0 for AES vendor version 2.
                self.check_crc()
        else:
            super().check_integrity()
class AESZipFile(ZipFile):
    """ZipFile that can use ``WZ_AES`` encryption.

    Uses AESZipInfo for entries and AESZipExtFile for reading them.
    """

    zipinfo_cls = AESZipInfo
    zipextfile_cls = AESZipExtFile

    def __init__(self, *args, **kwargs):
        """Create the archive object.

        Accepts the same arguments as the base ZipFile plus two optional
        keyword arguments, which are removed before the base initialiser is
        called:

        Args:
            encryption: Encryption scheme to use, e.g. ``WZ_AES``, or None.
            encryption_kwargs: Extra keyword arguments passed to the
                encrypter, or None.
        """
        encryption = kwargs.pop('encryption', None)
        encryption_kwargs = kwargs.pop('encryption_kwargs', None)
        super().__init__(*args, **kwargs)
        self.encryption = encryption
        self.encryption_kwargs = encryption_kwargs

    def get_encrypter(self):
        """Return a new AESZipEncrypter, or None if encryption is not WZ_AES.

        The encrypter is built from `self.pwd` and `self.encryption_kwargs`.
        """
        if self.encryption != WZ_AES:
            return None

        if self.encryption_kwargs is None:
            encryption_kwargs = {}
        else:
            encryption_kwargs = self.encryption_kwargs

        return AESZipEncrypter(pwd=self.pwd, **encryption_kwargs)