#!/usr/bin/env python3

import json
import os
import sys
import urllib.request
import shutil
import xml.etree.ElementTree as ET
import unicodedata

# Git ref of the citation-style-language/locales repository that the locale
# files are downloaded from; it is also recorded in the output's credit string.
LOCALES_GIT_REF = "master"


def main():
    """Merge selected CSL locale terms into ``cslLocaleStrings.json``.

    The first command-line argument must be the path to a Zotero source
    directory. For every language pack listed below, the CSL locale XML is
    downloaded at ``LOCALES_GIT_REF`` and its "edition" terms, short ordinal
    suffixes and long ordinals are collected. The combined result is written
    to ``resource/schema/cslLocaleStrings.json`` inside that directory.

    Returns:
        1 if the arguments are invalid or a locale file cannot be downloaded;
        None on success (which ``sys.exit`` treats as exit status 0).
    """
    # Imported locally so the error classes are available without relying on
    # urllib.request exposing its sibling submodule as a side effect.
    import urllib.error

    if len(sys.argv) < 2 or not os.path.isdir(sys.argv[1]):
        sys.stderr.write(
            "Usage: {0} path/to/zotero/source\n".format(os.path.basename(sys.argv[0]))
        )
        return 1

    source_dir = sys.argv[1]
    schema_dir = os.path.join(source_dir, "resource", "schema")
    output_path = os.path.join(schema_dir, "cslLocaleStrings.json")

    csl_locale_base = "https://raw.githubusercontent.com/citation-style-language/locales/{ref}/locales-{lang}.xml"

    # Codes for the language packs that we want to grab
    language_packs = [
        "af-ZA",
        "ar",
        "bg-BG",
        "ca-AD",
        "cs-CZ",
        "cy-GB",
        "da-DK",
        "de-AT",
        "de-CH",
        "de-DE",
        "el-GR",
        "en-GB",
        "en-US",
        "es-CL",
        "es-ES",
        "es-MX",
        "et-EE",
        "eu",
        "fa-IR",
        "fi-FI",
        "fr-CA",
        "fr-FR",
        "he-IL",
        "hi-IN",
        "hr-HR",
        "hu-HU",
        "id-ID",
        "is-IS",
        "it-IT",
        "ja-JP",
        "km-KH",
        "ko-KR",
        "la",
        "lt-LT",
        "lv-LV",
        "mn-MN",
        "nb-NO",
        "nl-NL",
        "nn-NO",
        "pl-PL",
        "pt-BR",
        "pt-PT",
        "ro-RO",
        "ru-RU",
        "sk-SK",
        "sl-SI",
        "sr-RS",
        "sv-SE",
        "th-TH",
        "tr-TR",
        "uk-UA",
        "vi-VN",
        "zh-CN",
        "zh-TW",
    ]

    number_formats = {}

    for lang in language_packs:
        url = csl_locale_base.format(ref=LOCALES_GIT_REF, lang=lang)

        print("Loading from " + url)
        # urlopen() raises for HTTP error statuses and network failures, so
        # those must be caught here to report them the same way as the
        # explicit status check below.
        try:
            response = urllib.request.urlopen(url)
        except urllib.error.URLError as err:
            sys.stderr.write("Failed to load {0}: {1}\n".format(url, err))
            return 1

        with response:
            code = response.getcode()
            if code != 200:
                sys.stderr.write("Got {0} for {1}\n".format(code, url))
                return 1
            xml = ET.parse(response)

        number_formats[lang] = _collect_locale_strings(xml)

    number_formats[
        "credit"
    ] = f"Generated from the CSL locales repository <https://github.com/citation-style-language/locales/tree/{LOCALES_GIT_REF}> by https://github.com/zotero/zotero-build/blob/master/locale/merge_csl_locales"

    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # is pinned to UTF-8 instead of depending on the platform default.
    with open(output_path, "w", encoding="utf-8") as outfile:
        json.dump(number_formats, outfile, ensure_ascii=False, indent='\t')
        print(f'Saved combined locales to {output_path}')


def _collect_locale_strings(xml):
    """Extract the edition and ordinal terms from one parsed CSL locale.

    ``xml`` is the parsed locale document. Returns a dict of the form
    ``{"locators": {"edition": [...]}, "ordinals": {"short": [...], "long": {...}}}``.
    The two lists are sorted so that repeated runs produce identical output.
    """
    edition_locators = set()
    short_ordinal_suffixes = set()
    long_ordinals = {}

    for term in xml.findall(".//{*}term"):
        name = term.attrib.get("name", "")
        if name == "edition":
            # the translations for "edition", "editions", and "ed."
            edition_locators.update(get_all_values(term))
        elif name.startswith("long-ordinal-"):
            # the translations for "first", "second", "third", etc.
            # Terms without text are skipped: they would become a null key.
            if term.text:
                # parse the "01" in "long-ordinal-01"
                long_ordinals[term.text] = int(name.rsplit("-", 1)[1])
        elif name.startswith("ordinal") and term.text:
            # the translations for "-st", "-nd", "-rd", and "-th", both as
            # written and with superscript characters flattened
            short_ordinal_suffixes.add(term.text)
            short_ordinal_suffixes.add(strip_superscript_chars(term.text))

    return {
        "locators": {"edition": sorted(edition_locators)},
        "ordinals": {"short": sorted(short_ordinal_suffixes), "long": long_ordinals},
    }


def get_all_values(elem):
    """Yield the non-empty text values carried by a CSL term element.

    Yields, in order: the element's own text with surrounding whitespace
    stripped, then the text of its ``single`` child, then the text of its
    ``multiple`` child. Missing or empty values are skipped. Child elements
    are matched in any XML namespace, or none.
    """
    # elem.text is None when the element is empty or its first child follows
    # the start tag directly, so fall back to "" before stripping.
    own_text = (elem.text or "").strip()
    if own_text:
        yield own_text
    for child_tag in ("{*}single", "{*}multiple"):
        child_text = elem.findtext(child_tag)
        if child_text:
            yield child_text


def strip_superscript_chars(s):
    """Replace all Unicode superscript characters in a string with their
    non-superscript counterparts and return the modified string.

    A character counts as superscript when its Unicode decomposition mapping
    is tagged ``<super>``. Every other character is copied through unchanged.
    """
    prefix = "<super> "
    output = []
    for c in s:
        decomposition = unicodedata.decomposition(c)
        if decomposition.startswith(prefix):
            # The mapping may list several space-separated code points
            # (e.g. U+2122 maps to "0054 004D"), so convert each one.
            output.extend(
                chr(int(code_point, 16))
                for code_point in decomposition[len(prefix):].split()
            )
        else:
            output.append(c)
    return "".join(output)


if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status.
    raise SystemExit(main())
