#CHIPSEC: Platform Security Assessment Framework
#Copyright (c) 2010-2022, Intel Corporation
#
#This program is free software; you can redistribute it and/or
#modify it under the terms of the GNU General Public License
#as published by the Free Software Foundation; Version 2.
#
#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.
#
#You should have received a copy of the GNU General Public License
#along with this program; if not, write to the Free Software
#Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#Contact information:
#chipsec@intel.com
#

"""
Access to the PCI/PCIe device hierarchy
- enumerating PCI/PCIe devices
- read/write access to PCI configuration headers/registers
- enumerating PCI expansion (option) ROMs
- identifying PCI/PCIe devices MMIO and I/O ranges (BARs)

usage:
    >>> self.cs.pci.read_byte( 0, 0, 0, 0x88 )
    >>> self.cs.pci.write_byte( 0, 0, 0, 0x88, 0x1A )
    >>> self.cs.pci.enumerate_devices()
    >>> self.cs.pci.enumerate_xroms()
    >>> self.cs.pci.find_XROM( 2, 0, 0, True, True, 0xFED00000 )
    >>> self.cs.pci.get_device_bars( 2, 0, 0 )
    >>> self.cs.pci.get_DIDVID( 2, 0, 0 )
    >>> self.cs.pci.is_enabled( 2, 0, 0 )
"""

import struct
from collections import namedtuple
import itertools

from chipsec import defines
from chipsec.logger import logger, pretty_print_hex_buffer
from chipsec.file   import write_file
from chipsec.hal.pcidb  import VENDORS, DEVICES
from chipsec.exceptions import OsHelperError
from chipsec.defines    import is_all_ones

#
# PCI configuration header registers
#

# Common (type 0/1) registers
# Byte offsets into the PCI configuration header.
PCI_HDR_VID_OFF            = 0x0
PCI_HDR_DID_OFF            = 0x2
PCI_HDR_CMD_OFF            = 0x4
PCI_HDR_STS_OFF            = 0x6
PCI_HDR_RID_OFF            = 0x8
# The class code spans offsets 0x9-0xB (programming interface, sub-class,
# class); its first byte is the programming interface, hence CLSCODE and PI
# share offset 0x9.
PCI_HDR_CLSCODE_OFF        = 0x9
PCI_HDR_PI_OFF             = 0x9
PCI_HDR_SUB_CLS_OFF        = 0xA
PCI_HDR_CLS_OFF            = 0xB
PCI_HDR_CLSIZE_OFF         = 0xC
PCI_HDR_MLT_OFF            = 0xD
PCI_HDR_TYPE_OFF           = 0xE
PCI_HDR_BIST_OFF           = 0xF
PCI_HDR_CAP_OFF            = 0x34
PCI_HDR_INTRLN_OFF         = 0x3C
PCI_HDR_INTRPIN_OFF        = 0x3D
# BAR0 and the register that follows it; for a 64-bit memory BAR the second
# dword holds the upper 32 address bits.
PCI_HDR_BAR0_LO_OFF        = 0x10
PCI_HDR_BAR0_HI_OFF        = 0x14

# PCIe BAR register fields
# Low 4 bits of a memory BAR are attribute bits, not address bits.
PCI_HDR_BAR_CFGBITS_MASK   = 0xF

# Bit 0 selects memory (0) vs I/O (1) space.
PCI_HDR_BAR_IOMMIO_MASK    = 0x1
PCI_HDR_BAR_IOMMIO_MMIO    = 0
PCI_HDR_BAR_IOMMIO_IO      = 1

# Bits 2:1 of a memory BAR encode its type (values below, after shifting).
PCI_HDR_BAR_TYPE_MASK      = (0x3<<1)
PCI_HDR_BAR_TYPE_SHIFT     = 1
PCI_HDR_BAR_TYPE_64B       = 2
PCI_HDR_BAR_TYPE_1MB       = 1
PCI_HDR_BAR_TYPE_32B       = 0

# Masks extracting the base address from a BAR value.
PCI_HDR_BAR_BASE_MASK_MMIO64 = 0xFFFFFFFFFFFFFFF0
PCI_HDR_BAR_BASE_MASK_MMIO   = 0xFFFFFFF0
PCI_HDR_BAR_BASE_MASK_IO     = 0xFFFC

# Type 0 specific registers
PCI_HDR_TYPE0_BAR1_LO_OFF  = 0x18
PCI_HDR_TYPE0_BAR1_HI_OFF  = 0x1C
PCI_HDR_TYPE0_BAR2_LO_OFF  = 0x20
PCI_HDR_TYPE0_BAR2_HI_OFF  = 0x24
PCI_HDR_TYPE0_XROM_BAR_OFF = 0x30

# Type 1 specific registers
PCI_HDR_TYPE1_XROM_BAR_OFF = 0x38

# Field defines

# Memory Space enable bit in the Command register.
PCI_HDR_CMD_MS_MASK        = 0x2

# Header Type register: bits 6:0 are the layout type, bit 7 flags a
# multi-function device.
PCI_HDR_TYPE_TYPE_MASK     = 0x7F
PCI_HDR_TYPE_MF_MASK       = 0x80

# Header layout types (value of Header Type & PCI_HDR_TYPE_TYPE_MASK).
PCI_TYPE0                  = 0x0
PCI_TYPE1                  = 0x1

# Expansion ROM BAR: bit 0 is the enable bit, bits 31:12 the base address.
PCI_HDR_XROM_BAR_EN_MASK   = 0x00000001
PCI_HDR_XROM_BAR_BASE_MASK = 0xFFFFF000

# Each BAR register is one dword wide.
PCI_HDR_BAR_STEP           = 0x4


#
# Generic/standard PCI Expansion (Option) ROM
#

XROM_SIGNATURE       = 0xAA55
PCI_XROM_HEADER_FMT  = '<H22sH'
PCI_XROM_HEADER_SIZE = struct.calcsize( PCI_XROM_HEADER_FMT )
class PCI_XROM_HEADER( namedtuple('PCI_XROM_HEADER', 'Signature ArchSpecific PCIROffset') ):
    """Generic PCI expansion ROM header as unpacked with ``PCI_XROM_HEADER_FMT``.

    Fields: ``Signature`` (expected to equal ``XROM_SIGNATURE``),
    ``ArchSpecific`` (the 22 raw bytes between the signature and the PCIR
    offset, as ``bytes``) and ``PCIROffset``.
    """
    __slots__ = ()
    def __str__(self):
        # struct's 's' code yields bytes; bytes.hex() replaces the Python 2-only
        # str.encode('hex'), which raises on Python 3.
        return """
PCI XROM
-----------------------------------
Signature       : 0x{:04X} (= 0xAA55)
ArchSpecific    : {}
PCIR Offset     : 0x{:04X}
""".format( self.Signature, self.ArchSpecific.hex().upper(), self.PCIROffset )

# @TBD: PCI Data Structure

#
# EFI specific PCI Expansion (Option) ROM
#

EFI_XROM_SIGNATURE   = 0x0EF1
# Layout of EFI_PCI_EXPANSION_ROM_HEADER: Reserved is UINT8[8], so the header
# is 0x1A bytes and the PCIR offset sits at 0x18, the same position as in the
# generic and legacy XROM headers.
EFI_XROM_HEADER_FMT  = '<HHIHHH8sHH'
EFI_XROM_HEADER_SIZE = struct.calcsize( EFI_XROM_HEADER_FMT )
class EFI_XROM_HEADER( namedtuple('EFI_XROM_HEADER', 'Signature InitSize EfiSignature EfiSubsystem EfiMachineType CompressType Reserved EfiImageHeaderOffset PCIROffset') ):
    """EFI PCI expansion ROM header as unpacked with ``EFI_XROM_HEADER_FMT``.

    ``EfiSignature`` is expected to equal ``EFI_XROM_SIGNATURE``; ``Reserved``
    is the raw 8-byte reserved field (``bytes``).
    """
    __slots__ = ()
    def __str__(self):
        return """
EFI PCI XROM
---------------------------------------
Signature           : 0x{:04X} (= 0xAA55)
Init Size           : 0x{:04X} (x 512 B)
EFI Signature       : 0x{:08X} (= 0x0EF1)
EFI Subsystem       : 0x{:04X}
EFI Machine Type    : 0x{:04X}
Compression Type    : 0x{:04X}
Reserved            : {}
EFI Image Hdr Offset: 0x{:04X}
PCIR Offset         : 0x{:04X}
""".format( self.Signature, self.InitSize, self.EfiSignature, self.EfiSubsystem, self.EfiMachineType, self.CompressType, self.Reserved.hex().upper(), self.EfiImageHeaderOffset, self.PCIROffset )

#
# Legacy PCI Expansion (Option) ROM
#

XROM_HEADER_FMT  = '<HBI17sH'
XROM_HEADER_SIZE = struct.calcsize( XROM_HEADER_FMT )
class XROM_HEADER( namedtuple('XROM_HEADER', 'Signature InitSize InitEP Reserved PCIROffset') ):
    """Legacy PCI expansion ROM header as unpacked with ``XROM_HEADER_FMT``.

    ``InitSize`` is in 512-byte units; ``Reserved`` is the raw 17-byte field
    (``bytes``) preceding ``PCIROffset``.
    """
    __slots__ = ()
    def __str__(self):
        # struct's 's' code yields bytes; bytes.hex() replaces the Python 2-only
        # str.encode('hex'), which raises on Python 3.
        return """
XROM
--------------------------------------
Signature           : 0x{:04X}
Init Size           : 0x{:02X} (x 512 B)
Init Entry-point    : 0x{:08X}
Reserved            : {}
PCIR Offset         : 0x{:04X}
""".format( self.Signature, self.InitSize, self.InitEP, self.Reserved.hex().upper(), self.PCIROffset )


class XROM(object):
    """Expansion ROM window discovered on one PCI function.

    ``vid``/``did`` start out as 0xFFFF placeholders and ``header`` as None;
    the enumeration code fills them in once they are known.
    """
    def __init__(self, bus, dev, fun, en, base, size):
        # Bus/device/function owning the ROM BAR.
        self.bus, self.dev, self.fun = bus, dev, fun
        # Vendor/device IDs are not known at construction time.
        self.vid = self.did = 0xFFFF
        # Enable state and address window decoded from the XROM BAR.
        self.en, self.base, self.size = en, base, size
        # Parsed ROM header, populated later if a ROM image is found.
        self.header = None


def get_vendor_name_by_vid( vid ):
    """Return the vendor name for ``vid`` from the PCI IDs database, or '' if unknown."""
    return VENDORS[vid] if vid in VENDORS else ''

def get_device_name_by_didvid( vid, did ):
    """Return the device name for ``vid``/``did``, or '' when either ID is unknown."""
    if vid not in DEVICES:
        return ''
    vendor_devices = DEVICES[vid]
    return vendor_devices[did] if did in vendor_devices else ''

def print_pci_devices( _devices ):
    """Log a table of PCI devices with their vendor and device names.

    ``_devices`` is an iterable of ``(bus, dev, fun, vid, did)`` tuples, as
    produced by ``Pci.enumerate_devices``.
    """
    title = "BDF     | VID:DID   | Vendor                       | Device"
    rule  = "-------------------------------------------------------------------------"
    logger().log( title )
    logger().log( rule )
    row_fmt = "{:02X}:{:02X}.{:X} | {:04X}:{:04X} | {:28} | {}"
    for bus, dev, fun, vid, did in _devices:
        vendor = get_vendor_name_by_vid( vid )
        device = get_device_name_by_didvid( vid, did )
        logger().log( row_fmt.format(bus, dev, fun, vid, did, vendor, device) )

def print_pci_XROMs( _xroms ):
    """Log a table of ``XROM`` objects; nothing is logged for an empty list."""
    if not _xroms:
        return
    logger().log( "BDF     | VID:DID   | XROM base | XROM size | en " )
    logger().log( "-------------------------------------------------" )
    row_fmt = "{:02X}:{:02X}.{:X} | {:04X}:{:04X} | {:08X}  | {:08X}  | {:d}"
    for rom in _xroms:
        logger().log( row_fmt.format(rom.bus, rom.dev, rom.fun, rom.vid, rom.did, rom.base, rom.size, rom.en) )


class Pci:
    """HAL component for PCI configuration space, expansion ROMs and BARs.

    Configuration accesses go through ``cs.helper.read_pci_reg`` /
    ``write_pci_reg``; XROM contents are read through ``cs.mem``.
    """

    def __init__( self, cs ):
        self.cs     = cs
        self.helper = cs.helper

    #
    # Access to PCI configuration registers
    #

    def read_dword(self, bus, device, function, address ):
        """Read a 32-bit value at ``address`` in the config space of bus/device/function."""
        value = self.helper.read_pci_reg( bus, device, function, address, 4 )
        if logger().HAL:
            logger().log( "[pci] reading B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:08X}".format(bus, device, function, address, value) )
        return value

    def read_word(self, bus, device, function, address ):
        """Read a 16-bit value at ``address`` in the config space of bus/device/function."""
        word_value = self.helper.read_pci_reg( bus, device, function, address, 2 )
        if logger().HAL:
            logger().log( "[pci] reading B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:04X}".format(bus, device, function, address, word_value) )
        return word_value

    def read_byte(self, bus, device, function, address ):
        """Read an 8-bit value at ``address`` in the config space of bus/device/function."""
        byte_value = self.helper.read_pci_reg( bus, device, function, address, 1 )
        if logger().HAL:
            logger().log( "[pci] reading B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:02X}".format(bus, device, function, address, byte_value) )
        return byte_value

    def write_byte(self, bus, device, function, address, byte_value ):
        """Write an 8-bit value at ``address`` in the config space of bus/device/function."""
        self.helper.write_pci_reg( bus, device, function, address, byte_value, 1 )
        if logger().HAL:
            logger().log( "[pci] writing B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:02X}".format(bus, device, function, address, byte_value) )
        return

    def write_word(self, bus, device, function, address, word_value ):
        """Write a 16-bit value at ``address`` in the config space of bus/device/function."""
        self.helper.write_pci_reg( bus, device, function, address, word_value, 2 )
        if logger().HAL:
            logger().log( "[pci] writing B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:04X}".format(bus, device, function, address, word_value) )
        return

    def write_dword( self, bus, device, function, address, dword_value ):
        """Write a 32-bit value at ``address`` in the config space of bus/device/function."""
        self.helper.write_pci_reg( bus, device, function, address, dword_value, 4 )
        if logger().HAL:
            logger().log( "[pci] writing B/D/F: {:d}/{:d}/{:d}, offset: 0x{:02X}, value: 0x{:08X}".format(bus, device, function, address, dword_value) )
        return


    #
    # Enumerating PCI devices and dumping configuration space
    #

    def enumerate_devices(self, bus=None, device=None, function=None):
        """Return ``(bus, dev, fun, vid, did)`` for every function that responds.

        Each of ``bus``/``device``/``function`` restricts the scan to that single
        value when given; otherwise the full range (256/32/8) is scanned.
        A function is considered present when its VID/DID dword is not all ones.
        Functions whose access raises ``OsHelperError`` are skipped.
        """
        devices = []

        bus_range  = [bus]      if bus is not None      else range(256)
        dev_range  = [device]   if device is not None   else range(32)
        func_range = [function] if function is not None else range(8)

        for b, d, f in itertools.product(bus_range, dev_range, func_range):
            try:
                did_vid = self.read_dword(b, d, f, 0x0)
                if 0xFFFFFFFF != did_vid:
                    vid = did_vid & 0xFFFF
                    did = (did_vid >> 16) & 0xFFFF
                    devices.append((b, d, f, vid, did))
            except OsHelperError:
                if logger().HAL:
                    logger().log("[pci] unable to access B/D/F: {:d}/{:d}/{:d}".format(b, d, f))
        return devices

    def dump_pci_config( self, bus, device, function ):
        """Return the first 256 bytes of config space as a list of byte values.

        Bytes of each dword are emitted least-significant first.
        """
        cfg = []
        for off in range(0, 0x100, 4):
            tmp_val = self.read_dword(bus, device, function, off)
            for shift in range(0, 32, 8):
                cfg.append((tmp_val >> shift) & 0xFF)
        return cfg

    def print_pci_config_all( self ):
        """Enumerate all devices and hex-dump the config space of each."""
        logger().log( "[pci] enumerating available PCI devices..." )
        pci_devices = self.enumerate_devices()
        for (b, d, f, vid, did) in pci_devices:
            cfg_buf = self.dump_pci_config( b, d, f )
            logger().log( "\n[pci] PCI device {:02X}:{:02X}.{:02X} configuration:".format(b, d, f) )
            pretty_print_hex_buffer( cfg_buf )


    #
    # PCI Expansion ROM functions
    #

    @staticmethod
    def _xrom_size( xrom_bar ):
        """Return the window size implied by the base-address bits of an XROM BAR value.

        The result is confined to 32 bits; a bare ``~`` on a Python int would
        otherwise produce a negative number.
        """
        return ((~(xrom_bar & PCI_HDR_XROM_BAR_BASE_MASK)) & defines.MASK_32b) + 1

    def parse_XROM( self, xrom, xrom_dump=False ):
        """Parse the header of the ROM image mapped at ``xrom.base``.

        :param xrom: ``XROM`` describing the mapped ROM window
        :param xrom_dump: also write ``xrom.size`` bytes of the ROM to a file
        :returns: ``PCI_XROM_HEADER`` or None when the 0xAA55 signature is absent
        """
        xrom_sig = self.cs.mem.read_physical_mem_word( xrom.base )
        if xrom_sig != XROM_SIGNATURE: return None
        xrom_hdr_buf = self.cs.mem.read_physical_mem( xrom.base, PCI_XROM_HEADER_SIZE )
        xrom_hdr = PCI_XROM_HEADER( *struct.unpack_from( PCI_XROM_HEADER_FMT, xrom_hdr_buf ) )
        if xrom_dump:
            xrom_fname = 'xrom_{:X}-{:X}-{:X}_{:X}{:X}.bin'.format(xrom.bus, xrom.dev, xrom.fun, xrom.vid, xrom.did)
            xrom_buf = self.cs.mem.read_physical_mem( xrom.base, xrom.size ) # use xrom_hdr.InitSize ?
            write_file( xrom_fname, xrom_buf )
        return xrom_hdr

    def find_XROM( self, bus, dev, fun, try_init=False, xrom_dump=False, xrom_addr=None ):
        """Locate (and optionally parse/dump) the expansion ROM of a PCI function.

        :param try_init: when the XROM BAR reads 0, write the base mask to it to
                         probe whether an XROM is implemented; if ``xrom_addr`` is
                         also given, program the BAR with that address plus the
                         enable bit
        :param xrom_dump: dump the ROM image to a file when a ROM is found
        :param xrom_addr: MMIO address used to program an unprogrammed XROM BAR;
                          must be aligned to the ROM size
        :returns: ``(xrom_found, xrom)``; ``xrom`` is an ``XROM`` whenever the BAR
                  indicates an XROM exists, otherwise None
        """
        # return results
        xrom_found, xrom = False, None

        if logger().HAL: logger().log( "[pci] checking XROM in {:02X}:{:02X}.{:02X}".format(bus, dev, fun) )

        cmd = self.read_word(bus, dev, fun, PCI_HDR_CMD_OFF)
        ms = ((cmd & PCI_HDR_CMD_MS_MASK) == PCI_HDR_CMD_MS_MASK)
        if logger().HAL: logger().log( "[pci]   PCI CMD (memory space = {:d}): 0x{:04X}".format(ms, cmd) )

        # The XROM BAR lives at a different offset in Type 0 vs Type 1 headers.
        hdr_type = self.read_byte(bus, dev, fun, PCI_HDR_TYPE_OFF)
        _type = hdr_type & PCI_HDR_TYPE_TYPE_MASK
        xrom_bar_off = PCI_HDR_TYPE1_XROM_BAR_OFF if _type == PCI_TYPE1 else PCI_HDR_TYPE0_XROM_BAR_OFF

        xrom_bar = self.read_dword( bus, dev, fun, xrom_bar_off )
        orig_xrom_bar = xrom_bar
        xrom_exists = (xrom_bar != 0)
        # BAR read-back after writing the base mask; its cleared low base bits
        # encode the ROM size. Only available when try_init probed the BAR.
        sizing_bar = None

        if xrom_exists:
            if logger().HAL: logger().log( "[pci]   device programmed XROM BAR: 0x{:08X}".format(xrom_bar) )
        else:
            if logger().HAL: logger().log( "[pci]   device didn't program XROM BAR: 0x{:08X}".format(xrom_bar) )
            if try_init:
                self.write_dword( bus, dev, fun, xrom_bar_off, PCI_HDR_XROM_BAR_BASE_MASK )
                xrom_bar = self.read_dword( bus, dev, fun, xrom_bar_off )
                sizing_bar = xrom_bar
                xrom_exists = (xrom_bar != 0)
                if logger().HAL: logger().log( "[pci]   returned 0x{:08X} after writing {:08X}".format(xrom_bar, PCI_HDR_XROM_BAR_BASE_MASK) )
                if xrom_exists and (xrom_addr is not None):
                    # device indicates XROM may exist. Initialize its base with supplied MMIO address
                    xrom_align = self._xrom_size( xrom_bar ) # actual XROM alignment (= size)
                    if (xrom_addr & (xrom_align - 1)) != 0:
                        logger().warn( "XROM address 0x{:08X} must be aligned at 0x{:08X}".format(xrom_addr, xrom_align) )
                        # Undo the sizing write so the BAR doesn't keep the probe value.
                        self.write_dword( bus, dev, fun, xrom_bar_off, orig_xrom_bar )
                        return False, None
                    self.write_dword( bus, dev, fun, xrom_bar_off, (xrom_addr|PCI_HDR_XROM_BAR_EN_MASK) )
                    xrom_bar = self.read_dword( bus, dev, fun, xrom_bar_off )
                    if logger().HAL: logger().log( "[pci]   programmed XROM BAR with 0x{:08X}".format(xrom_bar) )

                # The original XROM BAR value is not restored here, so a newly
                # programmed ROM stays mapped for the parsing below.

        #
        # At this point, a device indicates that XROM exists. Let's check if XROM is really there
        #
        xrom_en   = ((xrom_bar & PCI_HDR_XROM_BAR_EN_MASK) == 0x1)
        xrom_base = (xrom_bar & PCI_HDR_XROM_BAR_BASE_MASK)
        # Size from the probe read-back when we have one; otherwise it is derived
        # from the BAR contents as before.
        # NOTE(review): for a BAR already programmed by firmware this is not a
        # probed size -- TODO probe it if exact sizes are needed.
        xrom_size = self._xrom_size( xrom_bar if sizing_bar is None else sizing_bar )

        if xrom_exists:
            if logger().HAL: logger().log( "[pci]   XROM: BAR = 0x{:08X}, base = 0x{:08X}, size = 0x{:X}, en = {:d}".format(xrom_bar, xrom_base, xrom_size, xrom_en) )
            xrom = XROM(bus, dev, fun, xrom_en, xrom_base, xrom_size)
            # Fill in the IDs before parsing so a dumped ROM file is named after
            # the real VID/DID rather than the 0xFFFF placeholders.
            xrom.did, xrom.vid = self.get_DIDVID( bus, dev, fun )
            if xrom_en and (xrom_base != PCI_HDR_XROM_BAR_BASE_MASK):
                xrom.header = self.parse_XROM( xrom, xrom_dump )
                xrom_found  = (xrom.header is not None)
                if xrom_found:
                    if logger().HAL:
                        logger().log( "[pci]   XROM found at 0x{:08X}".format(xrom_base) )
                        logger().log( xrom.header )

        if not xrom_found:
            if logger().HAL: logger().log( "[pci]   XROM was not found" )

        return xrom_found, xrom

    def enumerate_xroms( self, try_init=False, xrom_dump=False, xrom_addr=None ):
        """Run ``find_XROM`` on every enumerated device; return the XROMs found."""
        pci_xroms = []
        logger().log( "[pci] enumerating available PCI devices..." )
        pci_devices = self.enumerate_devices()
        for (b, d, f, vid, did) in pci_devices:
            exists, xrom = self.find_XROM( b, d, f, try_init, xrom_dump, xrom_addr )
            if exists:
                xrom.vid = vid
                xrom.did = did
                pci_xroms.append( xrom )
        return pci_xroms

    #
    # Enumerating PCI device MMIO and I/O ranges (BARs)
    #

    def calc_bar_size(self, bus, dev, fun, off, reg):
        """Return the size of the MMIO BAR at ``off`` by writing all ones and reading back.

        ``reg`` is the BAR's current value and is written back afterwards.
        @TODO: for 64-bit BARs need to write both BAR registers for size calculation
        """
        self.write_dword(bus, dev, fun, off, defines.MASK_32b)
        reg1 = self.read_dword(bus, dev, fun, off)
        self.write_dword(bus, dev, fun, off, reg)
        size = (~(reg1&PCI_HDR_BAR_BASE_MASK_MMIO) & defines.MASK_32b) + 1
        return size

    def get_device_bars( self, bus, dev, fun, bCalcSize=False ):
        """Return the initialized I/O and MMIO BARs of a PCI function.

        :param bCalcSize: size MMIO BARs with ``calc_bar_size`` (writes the BAR);
                          otherwise MMIO sizes are reported as 4KB
        :returns: list of ``(base, isMMIO, is64bit, BAR_reg_offset, BAR_reg_value, size)``;
                  for 64-bit BARs the offset is that of the low dword and the value
                  combines both dwords. I/O BAR size is hard-coded to 0x100.
        """
        _bars = []
        off   = PCI_HDR_BAR0_LO_OFF
        size  = defines.BOUNDARY_4KB
        # BAR registers above 0x14 are Type 0 specific; in Type 1 (bridge)
        # headers those offsets hold bridge registers and must not be decoded,
        # or overwritten when bCalcSize is set, as BARs.
        _type = self.read_byte( bus, dev, fun, PCI_HDR_TYPE_OFF ) & PCI_HDR_TYPE_TYPE_MASK
        last_bar_off = PCI_HDR_BAR0_HI_OFF if _type == PCI_TYPE1 else PCI_HDR_TYPE0_BAR2_HI_OFF
        while off <= last_bar_off:
            reg = self.read_dword(bus, dev, fun, off)
            if reg and reg != defines.MASK_32b:
                # BAR is initialized
                isMMIO = (PCI_HDR_BAR_IOMMIO_MMIO == (reg & PCI_HDR_BAR_IOMMIO_MASK))
                if isMMIO:
                    # MMIO BAR
                    bar_type = (reg&PCI_HDR_BAR_TYPE_MASK) >> PCI_HDR_BAR_TYPE_SHIFT
                    if PCI_HDR_BAR_TYPE_64B == bar_type:
                        # 64-bit MMIO BAR: the next dword holds the upper address bits
                        if bCalcSize: size = self.calc_bar_size(bus, dev, fun, off, reg)
                        off += PCI_HDR_BAR_STEP
                        reg_hi = self.read_dword( bus, dev, fun, off )
                        reg |= (reg_hi << 32)
                        base = (reg & PCI_HDR_BAR_BASE_MASK_MMIO64)
                        _bars.append( (base, isMMIO, True, off -PCI_HDR_BAR_STEP, reg, size) )
                    elif PCI_HDR_BAR_TYPE_1MB == bar_type:
                        # MMIO BAR below 1MB - not supported
                        pass
                    elif PCI_HDR_BAR_TYPE_32B == bar_type:
                        # 32-bit only MMIO BAR
                        base = (reg & PCI_HDR_BAR_BASE_MASK_MMIO)
                        if bCalcSize: size = self.calc_bar_size(bus, dev, fun, off, reg)
                        _bars.append( (base, isMMIO, False, off, reg, size) )
                else:
                    # I/O BAR
                    # @TODO: calculate I/O BAR size, hardcoded to 0x100 for now
                    base = (reg & PCI_HDR_BAR_BASE_MASK_IO)
                    _bars.append( (base, isMMIO, False, off, reg, 0x100) )
            off += PCI_HDR_BAR_STEP
        return _bars

    def get_DIDVID( self, bus, dev, fun ):
        """Return ``(did, vid)`` read from config offset 0."""
        didvid = self.read_dword( bus, dev, fun, 0x0 )
        vid = didvid & 0xFFFF
        did = (didvid >> 16) & 0xFFFF
        return (did, vid)

    def is_enabled( self, bus, dev, fun ):
        """Return False when either the VID or DID reads as all ones, True otherwise."""
        (did, vid) = self.get_DIDVID( bus, dev, fun )
        if (is_all_ones(vid,2)) or (is_all_ones(did,2)):
            return False
        return True
