from preprocess import read_btor2_as_lines

def _process_ios_node(node_to_str: dict[int, str], tokens: list[str]):
    "for `i`nputs, `o`utputs, and `s`tate nodes."
    node_id, _, _, symbol = tokens
    node_to_str[int(node_id)] = symbol

def _process_next_node(node_to_str: dict[int, str], tokens: list[str]):
    "for `next` nodes. (overriding the target node's name)"
    _, _, _, state_id, arg0_id = tokens
    state_name = node_to_str[int(state_id)]
    node_to_str[int(arg0_id)] = f"{state_name}'"  # override the name

def _process_const_node(node_to_str: dict[int, str], tokens: list[str]):
    "for `const` nodes."
    node_id, _, _, value = tokens
    node_to_str[int(node_id)] = value

def _process_ext_node(node_to_str: dict[int, str], tokens: list[str]):
    "for `sext` and `uext` nodes."
    node_id, op_name, _, arg0_id, length = tokens[:5]
    if length == "0":  # special case: no extension, just a rename
        if len(tokens) > 5:
            node_name = tokens[5]
            node_to_str[int(arg0_id)] = node_name  # override the name
    else:
        if len(tokens) > 5:
            node_name = tokens[5]
        else:
            arg0_name = node_to_str[int(arg0_id)]
            node_name = f"{op_name}_{length}({arg0_name})"
        node_to_str[int(node_id)] = node_name

def _process_slice_node(node_to_str: dict[int, str], tokens: list[str]):
    "for `slice` nodes."
    node_id, op_name, _, arg0_id, high, low = tokens[:6]
    if len(tokens) > 6:
        node_name = tokens[6]
    else:
        arg0_name = node_to_str[int(arg0_id)]
        node_name = f"{op_name}_{high}_{low}({arg0_name})"
    node_to_str[int(node_id)] = node_name

def _process_unary_node(node_to_str: dict[int, str], tokens: list[str]):
    "for all unary nodes."
    node_id, op_name, _, arg0_id = tokens[:4]
    if len(tokens) > 4:
        node_name = tokens[4]
    else:
        arg0_name = node_to_str[int(arg0_id)]
        node_name = f"{op_name}({arg0_name})"
    node_to_str[int(node_id)] = node_name

def _process_binary_node(node_to_str: dict[int, str], tokens: list[str]):
    node_id, op_name, _, arg0_id, arg1_id = tokens[:5]
    if len(tokens) > 5:
        node_name = tokens[5]
    else:
        arg0_name = node_to_str[int(arg0_id)]
        arg1_name = node_to_str[int(arg1_id)]
        node_name = f"{op_name}({arg0_name}, {arg1_name})"
    node_to_str[int(node_id)] = node_name

def _process_ternary_node(node_to_str: dict[int, str], tokens: list[str]):
    "for all ternary nodes."
    node_id, op_name, _, arg0_id, arg1_id, arg2_id = tokens[:6]
    if len(tokens) > 6:
        node_name = tokens[6]
    else:
        arg0_name = node_to_str[int(arg0_id)]
        arg1_name = node_to_str[int(arg1_id)]
        arg2_name = node_to_str[int(arg2_id)]
        node_name = f"{op_name}({arg0_name}, {arg1_name}, {arg2_name})"
    node_to_str[int(node_id)] = node_name

_unsupported = {
    "constd", "consth", "one", "ones", "zero",
    "bad", "constraint", "fair", "justice",
}
"A set of unsupported node types."

_skipped = {
    "sort", "init",
}
"A set of skipped node types."

# Maps node (operator) names to the function that processes them.
# Operators sharing a handler are grouped with dict.fromkeys; the key order
# is the same as listing every entry individually.
_node_to_processor = {
    # IOS:
    **dict.fromkeys(("input", "output", "state"), _process_ios_node),

    # NEXT:
    "next": _process_next_node,

    # CONST:
    "const": _process_const_node,

    # EXTENSION:
    **dict.fromkeys(("sext", "uext"), _process_ext_node),

    # SLICING:
    "slice": _process_slice_node,

    # UNARY:
    **dict.fromkeys(
        ("not", "inc", "dec", "neg", "redand", "redor", "redxor"),
        _process_unary_node,
    ),

    # BINARY:
    **dict.fromkeys(
        (
            ## Boolean:
            "iff", "implies",
            ## Equality:
            "eq", "neq",
            ## Comparison:
            "sgt", "ugt", "sgte", "ugte", "slt", "ult", "slte", "ulte",
            ## Bit-wise:
            "and", "nand", "nor", "or", "xnor", "xor",
            ## Shift:
            "rol", "ror", "sll", "sra", "srl",
            ## Arithmetic:
            "add", "mul", "sdiv", "udiv", "smod", "srem", "urem", "sub",
            ## Overflow:
            "saddo", "uaddo", "sdivo", "udivo",
            "smulo", "umulo", "ssubo", "usubo",
            ## Concatenation:
            "concat",
            ## Array:
            "read",
        ),
        _process_binary_node,
    ),

    # TERNARY:
    **dict.fromkeys(("ite", "write"), _process_ternary_node),
}

def _process_node(node_to_str: dict[int, str], line: str):
    "The main entry for processing a node in a line."
    tokens = line.split()
    node_type = tokens[1]
    if node_type in _node_to_processor:
        _node_to_processor[node_type](node_to_str, tokens)
    elif node_type in _skipped:
        pass
    elif node_type in _unsupported:
        raise NotImplementedError(f"Unsupported node type: {node_type}")
    else:
        raise ValueError(f"Unknown node type: {node_type}")



class Btor2NodeNamePrinter:
    """Builds and serves printable names for the nodes of a BTOR2 file."""

    lines: list[str]
    """The comment-less, stripped BTOR2 lines given to the constructor."""
    node_to_str: dict[int, str]
    """Node ID -> string representation, filled in by the constructor."""

    def __init__(self, lines: list[str]):
        """Process every line once, in order, to populate `node_to_str`."""
        names: dict[int, str] = {}
        self.lines = lines
        self.node_to_str = names
        for btor2_line in lines:
            _process_node(names, btor2_line)

    def get_str(self, node_id: int, suffix: str) -> str:
        """Return the name recorded for `node_id` with `suffix` appended.

        Raises KeyError when no name is recorded for `node_id`.
        """
        base_name = self.node_to_str[node_id]
        return base_name + suffix


def test():
    """Smoke test: print the names of nodes 1..25 of a sample BTOR2 file."""
    btor2_lines = read_btor2_as_lines("../data/gcd_toy/gcd_slow2x.btor2")
    name_printer = Btor2NodeNamePrinter(btor2_lines)
    for node_id in range(1, 26):
        # Not every ID gets a name; skip the ones that do not.
        if node_id not in name_printer.node_to_str:
            continue
        print(f"{node_id} = {name_printer.get_str(node_id, '@0')}")

# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()