#!/usr/bin/env python
#
# Copyright (c) 2012 Intel, Inc.
# License: GPLv2
# Author: Artem Bityutskiy <artem.bityutskiy@linux.intel.com>

"""
Write raw image file to the target block device using the block map file.
The block map file contains the list of blocks which have to be written.

TODO:   switch to using the logger
        implement --quiet option
"""

import argparse
import os
import sys
import stat
import time

class FlasherException(Exception):
    """ The error type used by all the flasher helpers. 'main()' catches it
        and reports its text to the user as a fatal error. """

def parse_arguments(argv = None):
    """ A helper function for 'main()' which parses the input arguments.

        'argv' is the list of arguments to parse, not including the program
        name. When it is None (the default), argparse parses 'sys.argv[1:]',
        i.e., the script's own command line. Returns the resulting
        'argparse.Namespace' object. """

    # The program description
    text = ("Flash an image file to a block device using the block map "
            "(bmap). Although the bmap is optional, it generally improves "
            "flashing speed a lot because with bmap you end up writing only "
            "those parts of the image file to the block device, which "
            "actually contain useful data. And you do not write useless data "
            "there. For example, you may have a 4GiB image file, which "
            "contains only 100MiB of user data. In this case, with the bmap "
            "file you will write only a little bit more than 100MiB of data "
            "from the image file to the block device. Which is a lot faster "
            "than writing the entire 4GiB image. We say that it is a bit "
            "more than 100MiB because there are also file-system meta-data, "
            "partition table, etc. The bmap file is quite human-readable and "
            "you may take a look at it. It is an XML document which contains "
            "the list of blocks in the image file which have to be copied to "
            "the block device. There are also nice comments which make it "
            "easy to understand the bmap file contents.")

    parser = argparse.ArgumentParser(description = text)

    # The first command-line argument - block device node
    text = "The block device node to flash the image to."
    parser.add_argument("bdev", help = text)

    # The second command-line argument - image file. Keep the list of formats
    # in sync with 'open_image_file()'.
    text = ("The image file to flash. Supported formats: uncompressed, "
            ".tar.bz2, .tar.gz, .tgz, .bz2, .gz.")
    parser.add_argument("image", help = text)

    # The --bmap option
    text = "The block map file for the image."
    parser.add_argument("--bmap", help = text)

    # The --no-verify option
    text = "Do not verify the data checksum while writing."
    parser.add_argument("--no-verify", action="store_true", help = text)

    # The --no-sync option
    text = ("Do not synchronize the block device after flashing (use "
            "carefully and make sure you synchronize the block device "
            "manually before you unplug it).")
    parser.add_argument("--no-sync", action="store_true", help = text)

    return parser.parse_args(argv)

def message(msg, stream = None):
    """ Just a simple wrapper to print a message prefixed with the program
        name, followed by a newline.

        'stream' is the file object to write to, e.g., 'sys.stderr'. When it
        is omitted (or None), the current 'sys.stdout' is used. """

    if stream is None:
        # Look 'sys.stdout' up at call time instead of binding it as a default
        # argument at import time, so that later redirections are honoured.
        stream = sys.stdout

    program_name = "bmap-flasher"
    # 'stream.write()' works on both Python 2 and Python 3, unlike the
    # Python 2-only 'print >> stream' statement.
    stream.write("%s: %s\n" % (program_name, msg))

def fatal(msg):
    """ Handle an unrecoverable error. 'msg' is printed to stderr through
        'message()' with a "fatal error: " prefix, and then the script is
        terminated with exit status 1 (by raising SystemExit). """

    report = "fatal error: " + msg
    message(report, stream = sys.stderr)
    sys.exit(1)

def human_size(size):
    """ Transform size in bytes into a human-readable form.

        The result always uses binary units starting from KiB, e.g.,
        1536 -> "1.5 KiB" and 3 * 1024 ** 5 -> "3.0 PiB". Sizes of 1024 EiB
        and more are still expressed in EiB. """

    for modifier in ["KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]:
        size /= 1024.0
        if size < 1024:
            return "%.1f %s" % (size, modifier)

    # The loop has already converted 'size' to EiB units at this point
    return "%.1f %s" % (size, 'EiB')

def check_bmap_version(version):
    """ Check the bmap file format version and if it is incompatible, refuse
        the bmap file.

        'version' is the version string from the bmap file, e.g., "1.0". Only
        the major number (the part before the first '.') is checked: it must
        not be greater than the supported major version.

        Raises FlasherException if the version is missing, malformed, or has
        an unsupported major number. """

    supported_version = 1

    if not version:
        raise FlasherException("the bmap file does not specify the format " \
                               "version")

    try:
        major = int(version.split('.', 1)[0])
    except ValueError:
        raise FlasherException("bad bmap format version '%s'" % version)

    if major > supported_version:
        raise FlasherException("only bmap format version up to %d is " \
                               "supported, version %d is not supported" \
                               % (supported_version, major))

def copy_data(image, f_image, bdev, f_bdev, first, last, block_size, sha1):
    """ Copy the ['first'-'last'] region of the image file to the same region
        of the block device. The 'first' and 'last' are the block numbers
        (both inclusive), not byte offsets, and 'block_size' is the block size
        in bytes.

        Calculate the SHA1 checksum for the region and make sure it is
        equivalent to 'sha1', unless the 'sha1' argument is None (or empty),
        in which case the checksum is not calculated.

        'image' and 'bdev' are names for the image and the block device; they
        are only used in error messages. 'f_image' and 'f_bdev' are the
        corresponding file objects.

        Raises FlasherException on read or write errors, when a read from the
        image file returns no data (the image is too short), or on a checksum
        mismatch. """

    if sha1:
        import hashlib
        hash_obj = hashlib.sha1()

    start = first * block_size
    f_image.seek(start)
    f_bdev.seek(start)

    # Copy in chunks of about 1MiB. Integer division keeps the chunk size an
    # integer on Python 3 too. At least one block is copied per iteration,
    # because a block size bigger than 1MiB would otherwise give a zero chunk
    # size (and a bogus "too short" error).
    chunk_size = max(1, (1024 * 1024) // block_size)
    blocks_to_write = last - first + 1
    blocks_written = 0
    while blocks_written < blocks_to_write:
        if blocks_written + chunk_size > blocks_to_write:
            chunk_size = blocks_to_write - blocks_written

        try:
            chunk = f_image.read(chunk_size * block_size)
        except IOError as err:
            raise FlasherException("error while reading blocks %d-%d of the " \
                                   "image file '%s': %s" \
                                   % (first + blocks_written,
                                      first + blocks_written + chunk_size - 1,
                                      image, err))

        if not chunk:
            raise FlasherException("cannot read block %d, the image file '%s' " \
                                   "is too short" \
                                   % (first + blocks_written, image))

        if sha1:
            hash_obj.update(chunk)

        try:
            f_bdev.write(chunk)
        except IOError as err:
            raise FlasherException("error while writing block %d to block " \
                                   "device '%s': %s" \
                                   % (first + blocks_written, bdev, err))

        blocks_written += chunk_size

    if sha1:
        calculated = hash_obj.hexdigest()
        if calculated != sha1:
            raise FlasherException("checksum mismatch for blocks range %d-%d: " \
                                   "calculated %s, should be %s" \
                                   % (first, last, calculated, sha1))

def write_with_bmap(image, f_image, bdev, f_bdev, bmap, f_bmap, verify):
    """ Write the image to the block device using the block map.

        'image', 'bdev', 'bmap' - file names for the image, block device, and
        bmap correspondingly. 'f_image', 'f_bdev', 'f_bmap' are the
        corresponding file objects. This function uses file names only for
        forming message strings, and the file objects are used for the actual
        I/O.

        If the 'verify' argument is True, this function also verifies the
        checksum of every block range which has a 'sha1' attribute in the
        bmap file.

        Raises FlasherException if the bmap file cannot be parsed, lacks
        mandatory elements, contains bad values, is inconsistent, or if
        copying the data fails.
        """

    from xml.etree import ElementTree

    try:
        xml = ElementTree.parse(f_bmap)
    except ElementTree.ParseError as err:
        raise FlasherException("cannot parse bmap file '%s': %s" % (bmap, err))

    def get_int(tag):
        """ Return the integer value of the mandatory element 'tag' of the
            bmap file. """

        element = xml.find(tag)
        if element is None or element.text is None:
            raise FlasherException("bmap file '%s' has no '%s' element" \
                                   % (bmap, tag))
        try:
            return int(element.text.strip())
        except ValueError:
            raise FlasherException("bad '%s' value '%s' in bmap file '%s'" \
                                   % (tag, element.text.strip(), bmap))

    version = xml.getroot().attrib.get('version')
    if version is None:
        raise FlasherException("bmap file '%s' does not specify the format " \
                               "version" % bmap)
    check_bmap_version(version)

    # Fetch interesting data from the bmap XML file
    block_size = get_int("BlockSize")
    blocks_cnt = get_int("BlocksCount")
    mapped_cnt = get_int("MappedBlocksCount")
    if block_size <= 0:
        raise FlasherException("bad block size %d in bmap file '%s'" \
                               % (block_size, bmap))
    total_size = blocks_cnt * block_size
    mapped_size = mapped_cnt * block_size

    xml_bmap = xml.find("BlockMap")
    if xml_bmap is None:
        raise FlasherException("bmap file '%s' has no 'BlockMap' element" \
                               % bmap)

    # Avoid a division by zero for a bmap file which describes no blocks
    if blocks_cnt:
        mapped_percent = (mapped_cnt * 100.0) / blocks_cnt
    else:
        mapped_percent = 0.0

    message("block map format version %s" % version)
    message("%d blocks of size %d (%s), mapped %d blocks (%s or %.1f%%)" \
            % (blocks_cnt, block_size, human_size(total_size), mapped_cnt,
               human_size(mapped_size), mapped_percent))
    message("writing the image to '%s' using bmap file '%s'" % (bdev, bmap))

    # Write the mapped blocks to the block device
    blocks_written = 0
    for xml_element in xml_bmap.findall("Range"):
        blocks_range = (xml_element.text or "").strip()
        # The range of blocks has X - Y format, or it can be just X. First,
        # split the blocks range string and strip white-spaces.
        split = [x.strip() for x in blocks_range.split('-', 1)]
        try:
            first = int(split[0])
            # A single block 'X' is the same as the range 'X - X'
            last = int(split[1]) if len(split) > 1 else first
        except ValueError:
            raise FlasherException("bad range '%s' in bmap file '%s'" \
                                   % (blocks_range, bmap))
        if first > last:
            raise FlasherException("bad range (first > last): '%s'" \
                                   % blocks_range)

        if verify:
            sha1 = xml_element.attrib.get('sha1')
        else:
            sha1 = None

        copy_data(image, f_image, bdev, f_bdev, first, last, block_size, sha1)

        blocks_written += last - first + 1

    # This is just a sanity check - we should have written exactly 'mapped_cnt'
    # blocks.
    if blocks_written != mapped_cnt:
        raise FlasherException("wrote %u blocks, but should have %u - " \
                               "inconsistent bmap file" \
                               % (blocks_written, mapped_cnt))

def open_image_file(image):
    """ A helper function which opens the image file 'image' for reading and
        returns the file object.

        The format is chosen by the file name extension:
          * '.tar.gz', '.tar.bz2', '.tgz' - a tarball which must contain
            exactly one member, a regular file;
          * '.gz', '.bz2' - a compressed file;
          * anything else - an uncompressed file.

        Raises FlasherException if the file cannot be opened or the tarball
        is not suitable. """

    import tarfile

    try:
        if image.endswith(('.tar.gz', '.tar.bz2', '.tgz')):
            tar = tarfile.open(image, 'r')
            # The tarball is supposed to contain only one single member
            members = tar.getnames()
            if len(members) > 1:
                tar.close()
                raise FlasherException("the image tarball '%s' contains more " \
                                       "than one file" % image)
            elif len(members) == 0:
                tar.close()
                raise FlasherException("the image tarball '%s' is empty " \
                                       "(no files)" % image)

            f_member = tar.extractfile(members[0])
            if f_member is None:
                # 'extractfile()' returns None for members which are neither
                # regular files nor links (e.g., directories)
                tar.close()
                raise FlasherException("the member '%s' of the image tarball " \
                                       "'%s' is not a regular file" \
                                       % (members[0], image))
            return f_member
        elif image.endswith('.gz'):
            import gzip
            return gzip.GzipFile(image, 'rb')
        elif image.endswith('.bz2'):
            import bz2
            return bz2.BZ2File(image, 'rb')
        else:
            return open(image, 'rb')
    except (IOError, tarfile.TarError) as err:
        raise FlasherException("cannot open image file '%s': %s" % (image, err))

def _open_block_device(bdev):
    """ Open the block device node 'bdev' for writing and return the
        corresponding file object. Warns if 'bdev' is not a block device and
        terminates the script via 'fatal()' if it cannot be opened. """

    # Open the block device in exclusive mode - this will fail if the block
    # device is used by someone, e.g., mounted.
    try:
        f_bdev_raw = os.open(bdev, os.O_RDWR | os.O_EXCL)
    except OSError as err:
        fatal("cannot open block device '%s': %s" % (bdev, err.strerror))

    try:
        is_block_device = stat.S_ISBLK(os.fstat(f_bdev_raw).st_mode)
    except OSError as err:
        fatal("cannot access block device '%s': %s" \
              % (bdev, err.strerror))

    if not is_block_device:
        message("warning!: '%s' is not a block device!" % bdev)

    # Turn the block device file descriptor into a file object
    try:
        return os.fdopen(f_bdev_raw, "wb")
    except IOError as err:
        fatal("cannot open block device '%s': %s" % (bdev, err.strerror))

def main():
    """ Script entry point.

        Parses the command line and writes the image to the block device:
        the whole image when no '--bmap' is given, otherwise only the blocks
        listed in the bmap file. Any error terminates the script via
        'fatal()'. """

    args = parse_arguments()

    f_bdev = _open_block_device(args.bdev)

    try:
        f_image = open_image_file(args.image)
    except FlasherException as err:
        fatal(str(err))

    start_time = time.time()
    if not args.bmap:
        import shutil

        message("no block map given (see the --bmap option)")
        message("falling-back to writing entire image to '%s'" % args.bdev)
        try:
            shutil.copyfileobj(f_image, f_bdev)
        except IOError as err:
            fatal("error while copying '%s' to '%s': %s" \
                  % (args.image, args.bdev, err))
    else:
        try:
            f_bmap = open(args.bmap, 'r')
        except IOError as err:
            fatal("cannot open bmap file '%s': %s" % (args.bmap, err.strerror))

        try:
            write_with_bmap(args.image, f_image, args.bdev, f_bdev,
                            args.bmap, f_bmap, not args.no_verify)
        except FlasherException as err:
            fatal(str(err))
        finally:
            # Runs also when 'fatal()' raises SystemExit above
            f_bmap.close()

    # Synchronize the block device. The file object is buffered, so errors
    # for already "written" data may only be reported by flush/fsync.
    if not args.no_sync:
        message("synchronizing block device '%s'" % args.bdev)
        try:
            f_bdev.flush()
            os.fsync(f_bdev.fileno())
        except (IOError, OSError) as err:
            fatal("cannot synchronize block device '%s': %s" \
                  % (args.bdev, err))

    message("flashing time: %.1f seconds" % (time.time() - start_time))
    f_image.close()

    # Closing flushes any data still buffered (e.g., with --no-sync), so a
    # failure here is reported instead of ending with a traceback.
    try:
        f_bdev.close()
    except (IOError, OSError) as err:
        fatal("cannot close block device '%s': %s" % (args.bdev, err))

if __name__ == "__main__":
    # Run the script; whatever 'main()' returns becomes the exit status.
    raise SystemExit(main())
