import pyboolector as bt
from preprocess import read_btor2_as_lines
from bidict import bidict

# Special Nodes:
def _process_sort_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Registers the sort declared by a `sort` line in `c.id_to_sort`.

    Accepted forms (tokens after the listed ones, e.g. a symbol, are ignored,
    the same way the operator processors slice their token lists):
        <id> sort bitvec <width>
        <id> sort array <index sort id> <element sort id>

    Raises `ValueError` for an unknown sort kind, or for an array whose index
    or element sort is not a bit-vector sort.
    Raises `KeyError` if an array refers to a sort ID not registered yet.
    """
    # currently, only bit-vector and single-dimensional array are supported:
    kind = tokens[2]
    if kind == "array":
        node_id, _, _, index_sort_id, elem_sort_id = tokens[:5]
        index_sort = c.id_to_sort[int(index_sort_id)]
        if not isinstance(index_sort, bt._BoolectorBitVecSort):
            raise ValueError(f"Array index sort is not a bit-vector: {index_sort}")
        elem_sort = c.id_to_sort[int(elem_sort_id)]
        if not isinstance(elem_sort, bt._BoolectorBitVecSort):
            raise ValueError(f"Array element sort is not a bit-vector: {elem_sort}")
        sort = c.solver.ArraySort(index_sort, elem_sort)
    elif kind == "bitvec":
        node_id, _, _, length = tokens[:4]
        sort = c.solver.BitVecSort(int(length))
    else:
        raise ValueError(f"Unknown sort: {kind}")
    c.id_to_sort[int(node_id)] = sort

def _process_output_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Records which node an `output` line exposes; no solver node is created.

    The line must have exactly four tokens, the last one being the output name.
    The name `valid` is stored in `c.valid_signal_id`; every other name goes
    into `c.name_to_output_id`.
    """
    _, _, target_token, output_name = tokens
    if raw_is_valid := (output_name == "valid"):
        c.valid_signal_id = int(target_token)
    if not raw_is_valid:
        c.name_to_output_id[output_name] = int(target_token)

def _process_state_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a solver variable for a `state` line and registers it by ID and name.

    The line must have exactly four tokens: ID, `state`, sort ID, state name.
    The variable is named after the state plus `c.suffix`.
    """
    id_token, _, sort_token, state_name = tokens
    variable = c.solver.Var(c.id_to_sort[int(sort_token)], state_name + c.suffix)
    state_id = int(id_token)
    c.id_to_node[state_id] = variable
    c.name_to_state_id[state_name] = state_id

def _process_input_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a solver variable for an `input` line.

    An input is treated as a special state that never changes value: the same
    node ID is registered as both its current state and its next state.
    The line must have exactly four tokens: ID, `input`, sort ID, input name.
    """
    id_token, _, sort_token, input_name = tokens
    variable = c.solver.Var(c.id_to_sort[int(sort_token)], input_name + c.suffix)
    input_id = int(id_token)
    c.id_to_node[input_id] = variable
    # Same ID on both sides: the "next" value of an input is the input itself.
    c.name_to_state_id[input_name] = input_id
    c.name_to_next_state_id[input_name] = input_id

def _process_init_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Records the initial-value node of a state; no solver node is created.

    Tokens 3 and 4 are the state ID and the value ID. The state must already
    be present in `c.name_to_state_id` (looked up through its inverse view).
    """
    _, _, _, state_token, value_token = tokens[:5]
    owner_name = c.name_to_state_id.inverse[int(state_token)]
    c.name_to_init_state_id[owner_name] = int(value_token)

def _process_next_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Records the next-state node of a state; no solver node is created.

    Tokens 3 and 4 are the state ID and the value ID. The state must already
    be present in `c.name_to_state_id` (looked up through its inverse view).
    """
    _, _, _, state_token, value_token = tokens[:5]
    owner_name = c.name_to_state_id.inverse[int(state_token)]
    c.name_to_next_state_id[owner_name] = int(value_token)

def _process_const_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a constant node for a `const` line and registers it in `c.id_to_node`.

    Token 3 is passed unchanged to `c.solver.Const`. Tokens after it (e.g. a
    symbol) are ignored, the same way the operator processors slice their
    token lists. Raises `ValueError` if the line has fewer than four tokens.
    """
    node_id, _, _, value = tokens[:4]
    node = c.solver.Const(value)
    c.id_to_node[int(node_id)] = node

# Indexing Nodes:
def _process_sext_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a sign-extension node via `c.solver.Sext`.

    Token 3 is the operand ID and token 4 the number of bits to add.
    When that number is 0 the line is skipped: nothing is looked up and no
    entry is added to `c.id_to_node` for this ID.
    """
    id_token, _, _, operand_token, length_token = tokens[:5]
    extra_bits = int(length_token)
    if extra_bits != 0:
        operand = c.id_to_node[int(operand_token)]
        c.id_to_node[int(id_token)] = c.solver.Sext(operand, extra_bits)

def _process_uext_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a zero-extension node via `c.solver.Uext`.

    Token 3 is the operand ID and token 4 the number of bits to add.
    When that number is 0 the line is skipped: nothing is looked up and no
    entry is added to `c.id_to_node` for this ID.
    """
    id_token, _, _, operand_token, length_token = tokens[:5]
    extra_bits = int(length_token)
    if extra_bits != 0:
        operand = c.id_to_node[int(operand_token)]
        c.id_to_node[int(id_token)] = c.solver.Uext(operand, extra_bits)

def _process_slice_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates a bit-slice node via `c.solver.Slice`.

    Token 3 is the operand ID; tokens 4 and 5 are the upper and lower bounds,
    passed to the solver in that order as integers.
    """
    id_token, _, _, operand_token, upper_token, lower_token = tokens[:6]
    operand = c.id_to_node[int(operand_token)]
    bounds = (int(upper_token), int(lower_token))
    c.id_to_node[int(id_token)] = c.solver.Slice(operand, *bounds)

# Unary Nodes:
def _process_general_unary_node(c: "Btor2Circuit", tokens: list[str], operator):
    """
    Shared implementation for one-operand nodes.

    Looks up the operand whose ID is token 3, applies `operator` to it and
    stores the result in `c.id_to_node` under the line's own ID (token 0).
    """
    id_token, _, _, operand_token = tokens[:4]
    operand = c.id_to_node[int(operand_token)]
    c.id_to_node[int(id_token)] = operator(operand)

def _process_not_node(c, tokens):
    "Handles `not` with `c.solver.Not`."
    _process_general_unary_node(c, tokens, c.solver.Not)

def _process_inc_node(c, tokens):
    "Handles `inc` with `c.solver.Inc`."
    _process_general_unary_node(c, tokens, c.solver.Inc)

def _process_dec_node(c, tokens):
    "Handles `dec` with `c.solver.Dec`."
    _process_general_unary_node(c, tokens, c.solver.Dec)

def _process_neg_node(c, tokens):
    "Handles `neg` with `c.solver.Neg`."
    _process_general_unary_node(c, tokens, c.solver.Neg)

def _process_redand_node(c, tokens):
    "Handles `redand` with `c.solver.Redand`."
    _process_general_unary_node(c, tokens, c.solver.Redand)

def _process_redor_node(c, tokens):
    "Handles `redor` with `c.solver.Redor`."
    _process_general_unary_node(c, tokens, c.solver.Redor)

def _process_redxor_node(c, tokens):
    "Handles `redxor` with `c.solver.Redxor`."
    _process_general_unary_node(c, tokens, c.solver.Redxor)

# Binary Nodes:
def _process_general_binary_node(c: "Btor2Circuit", tokens: list[str], operator):
    """
    Shared implementation for two-operand nodes.

    Looks up the operands whose IDs are tokens 3 and 4 (in that order), applies
    `operator(left, right)` and stores the result in `c.id_to_node` under the
    line's own ID (token 0).
    """
    id_token, _, _, left_token, right_token = tokens[:5]
    left = c.id_to_node[int(left_token)]
    right = c.id_to_node[int(right_token)]
    c.id_to_node[int(id_token)] = operator(left, right)

## Boolean:
def _process_iff_node(c, tokens):
    "Handles `iff` with `c.solver.Iff`."
    _process_general_binary_node(c, tokens, c.solver.Iff)

def _process_implies_node(c, tokens):
    "Handles `implies` with `c.solver.Implies`."
    _process_general_binary_node(c, tokens, c.solver.Implies)

## Equality:
def _process_eq_node(c, tokens):
    "Handles `eq` with `c.solver.Eq`."
    _process_general_binary_node(c, tokens, c.solver.Eq)

def _process_neq_node(c, tokens):
    "Handles `neq` with `c.solver.Ne`."
    _process_general_binary_node(c, tokens, c.solver.Ne)

## Comparison:
def _process_sgt_node(c, tokens):
    "Handles `sgt` with `c.solver.Sgt`."
    _process_general_binary_node(c, tokens, c.solver.Sgt)

def _process_ugt_node(c, tokens):
    "Handles `ugt` with `c.solver.Ugt`."
    _process_general_binary_node(c, tokens, c.solver.Ugt)

def _process_sgte_node(c, tokens):
    "Handles `sgte` with `c.solver.Sgte`."
    _process_general_binary_node(c, tokens, c.solver.Sgte)

def _process_ugte_node(c, tokens):
    "Handles `ugte` with `c.solver.Ugte`."
    _process_general_binary_node(c, tokens, c.solver.Ugte)

def _process_slt_node(c, tokens):
    "Handles `slt` with `c.solver.Slt`."
    _process_general_binary_node(c, tokens, c.solver.Slt)

def _process_ult_node(c, tokens):
    "Handles `ult` with `c.solver.Ult`."
    _process_general_binary_node(c, tokens, c.solver.Ult)

def _process_slte_node(c, tokens):
    "Handles `slte` with `c.solver.Slte`."
    _process_general_binary_node(c, tokens, c.solver.Slte)

def _process_ulte_node(c, tokens):
    "Handles `ulte` with `c.solver.Ulte`."
    _process_general_binary_node(c, tokens, c.solver.Ulte)

## Bitwise:
def _process_and_node(c, tokens):
    "Handles `and` with `c.solver.And`."
    _process_general_binary_node(c, tokens, c.solver.And)

def _process_nand_node(c, tokens):
    "Handles `nand` with `c.solver.Nand`."
    _process_general_binary_node(c, tokens, c.solver.Nand)

def _process_nor_node(c, tokens):
    "Handles `nor` with `c.solver.Nor`."
    _process_general_binary_node(c, tokens, c.solver.Nor)

def _process_or_node(c, tokens):
    "Handles `or` with `c.solver.Or`."
    _process_general_binary_node(c, tokens, c.solver.Or)

def _process_xnor_node(c, tokens):
    "Handles `xnor` with `c.solver.Xnor`."
    _process_general_binary_node(c, tokens, c.solver.Xnor)

def _process_xor_node(c, tokens):
    "Handles `xor` with `c.solver.Xor`."
    _process_general_binary_node(c, tokens, c.solver.Xor)

## Shift:
def _process_rol_node(c, tokens):
    "Handles `rol` with `c.solver.Rol`."
    _process_general_binary_node(c, tokens, c.solver.Rol)

def _process_ror_node(c, tokens):
    "Handles `ror` with `c.solver.Ror`."
    _process_general_binary_node(c, tokens, c.solver.Ror)

def _process_sll_node(c, tokens):
    "Handles `sll` with `c.solver.Sll`."
    _process_general_binary_node(c, tokens, c.solver.Sll)

def _process_sra_node(c, tokens):
    "Handles `sra` with `c.solver.Sra`."
    _process_general_binary_node(c, tokens, c.solver.Sra)

def _process_srl_node(c, tokens):
    "Handles `srl` with `c.solver.Srl`."
    _process_general_binary_node(c, tokens, c.solver.Srl)

## Arithmetic:
def _process_add_node(c, tokens):
    "Handles `add` with `c.solver.Add`."
    _process_general_binary_node(c, tokens, c.solver.Add)

def _process_mul_node(c, tokens):
    "Handles `mul` with `c.solver.Mul`."
    _process_general_binary_node(c, tokens, c.solver.Mul)

def _process_sdiv_node(c, tokens):
    "Handles `sdiv` with `c.solver.Sdiv`."
    _process_general_binary_node(c, tokens, c.solver.Sdiv)

def _process_udiv_node(c, tokens):
    "Handles `udiv` with `c.solver.Udiv`."
    _process_general_binary_node(c, tokens, c.solver.Udiv)

def _process_smod_node(c, tokens):
    "Handles `smod` with `c.solver.Smod`."
    _process_general_binary_node(c, tokens, c.solver.Smod)

def _process_srem_node(c, tokens):
    "Handles `srem` with `c.solver.Srem`."
    _process_general_binary_node(c, tokens, c.solver.Srem)

def _process_urem_node(c, tokens):
    "Handles `urem` with `c.solver.Urem`."
    _process_general_binary_node(c, tokens, c.solver.Urem)

def _process_sub_node(c, tokens):
    "Handles `sub` with `c.solver.Sub`."
    _process_general_binary_node(c, tokens, c.solver.Sub)

## Overflow:
def _process_saddo_node(c, tokens):
    "Handles `saddo` with `c.solver.Saddo`."
    _process_general_binary_node(c, tokens, c.solver.Saddo)

def _process_uaddo_node(c, tokens):
    "Handles `uaddo` with `c.solver.Uaddo`."
    _process_general_binary_node(c, tokens, c.solver.Uaddo)

def _process_sdivo_node(c, tokens):
    "Handles `sdivo` with `c.solver.Sdivo`."
    _process_general_binary_node(c, tokens, c.solver.Sdivo)

def _process_udivo_node(c, tokens):
    "Handles `udivo` with `c.solver.Udivo`."
    _process_general_binary_node(c, tokens, c.solver.Udivo)

def _process_smulo_node(c, tokens):
    "Handles `smulo` with `c.solver.Smulo`."
    _process_general_binary_node(c, tokens, c.solver.Smulo)

def _process_umulo_node(c, tokens):
    "Handles `umulo` with `c.solver.Umulo`."
    _process_general_binary_node(c, tokens, c.solver.Umulo)

def _process_ssubo_node(c, tokens):
    "Handles `ssubo` with `c.solver.Ssubo`."
    _process_general_binary_node(c, tokens, c.solver.Ssubo)

def _process_usubo_node(c, tokens):
    "Handles `usubo` with `c.solver.Usubo`."
    _process_general_binary_node(c, tokens, c.solver.Usubo)

## Concatenation:
def _process_concat_node(c, tokens):
    "Handles `concat` with `c.solver.Concat`."
    _process_general_binary_node(c, tokens, c.solver.Concat)

## Array:
def _process_read_node(c, tokens):
    "Handles `read` with `c.solver.Read`."
    _process_general_binary_node(c, tokens, c.solver.Read)

# Ternary Nodes:
def _process_ite_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates an if-then-else node via `c.solver.Cond`.

    Tokens 3, 4 and 5 are the IDs of the three operands, passed to the solver
    in that order. The result is stored under the line's own ID (token 0).
    """
    id_token, _, _, first_token, second_token, third_token = tokens[:6]
    # Operands are resolved left to right, one ID at a time.
    operands = [
        c.id_to_node[int(token)]
        for token in (first_token, second_token, third_token)
    ]
    c.id_to_node[int(id_token)] = c.solver.Cond(*operands)

def _process_write_node(c: "Btor2Circuit", tokens: list[str]):
    """
    Creates an array-write node via `c.solver.Write`.

    Tokens 3, 4 and 5 are the IDs of the three operands, passed to the solver
    in that order. The result is stored under the line's own ID (token 0).
    """
    id_token, _, _, first_token, second_token, third_token = tokens[:6]
    # Operands are resolved left to right, one ID at a time.
    operands = [
        c.id_to_node[int(token)]
        for token in (first_token, second_token, third_token)
    ]
    c.id_to_node[int(id_token)] = c.solver.Write(*operands)


# Node types that `_process_node` recognizes but rejects: a line of one of these
# types raises `NotImplementedError` instead of the `ValueError` used for
# completely unknown types.
_unsupported = {
    "constd", "consth", "one", "ones", "zero",
    "bad", "constraint", "fair", "justice",
}
"A set of unsupported node types."

# Dispatch table used by `_process_node`. The key is the node type (the second
# whitespace-separated token of a line); the value is called as
# `processor(c, tokens)` with the circuit being built and all tokens of the line.
_node_to_processor = {
    # Special Nodes:
    "sort": _process_sort_node,
    "output": _process_output_node,
    "state": _process_state_node,
    "input": _process_input_node,
    "init": _process_init_node,
    "next": _process_next_node,
    "const": _process_const_node,

    # Indexing Nodes:
    "sext": _process_sext_node,
    "uext": _process_uext_node,
    "slice": _process_slice_node,

    # Unary Nodes:
    "not": _process_not_node,
    "inc": _process_inc_node,
    "dec": _process_dec_node,
    "neg": _process_neg_node,
    "redand": _process_redand_node,
    "redor": _process_redor_node,
    "redxor": _process_redxor_node,

    # Binary Nodes:
    ## Boolean:
    "iff": _process_iff_node,
    "implies": _process_implies_node,
    ## Equality:
    "eq": _process_eq_node,
    "neq": _process_neq_node,
    ## Comparison:
    "sgt": _process_sgt_node,
    "ugt": _process_ugt_node,
    "sgte": _process_sgte_node,
    "ugte": _process_ugte_node,
    "slt": _process_slt_node,
    "ult": _process_ult_node,
    "slte": _process_slte_node,
    "ulte": _process_ulte_node,
    ## Bitwise:
    "and": _process_and_node,
    "nand": _process_nand_node,
    "nor": _process_nor_node,
    "or": _process_or_node,
    "xnor": _process_xnor_node,
    "xor": _process_xor_node,
    ## Shift:
    "rol": _process_rol_node,
    "ror": _process_ror_node,
    "sll": _process_sll_node,
    "sra": _process_sra_node,
    "srl": _process_srl_node,
    ## Arithmetic:
    "add": _process_add_node,
    "mul": _process_mul_node,
    "sdiv": _process_sdiv_node,
    "udiv": _process_udiv_node,
    "smod": _process_smod_node,
    "srem": _process_srem_node,
    "urem": _process_urem_node,
    "sub": _process_sub_node,
    ## Overflow:
    "saddo": _process_saddo_node,
    "uaddo": _process_uaddo_node,
    "sdivo": _process_sdivo_node,
    "udivo": _process_udivo_node,
    "smulo": _process_smulo_node,
    "umulo": _process_umulo_node,
    "ssubo": _process_ssubo_node,
    "usubo": _process_usubo_node,
    ## Concatenation:
    "concat": _process_concat_node,
    ## Array:
    "read": _process_read_node,

    # Ternary Nodes:
    "ite": _process_ite_node,
    "write": _process_write_node,
}
"Maps node names to corresponding processing functions."

def _process_node(c: "Btor2Circuit", line: str):
    """
    The main entry for processing a node in a line.

    The line is split on whitespace; the second token selects the processor
    from `_node_to_processor`, which receives the circuit and all tokens.
    Blank lines and comment lines (first token starting with `;`) are skipped.

    Raises `ValueError` for a line with a single token or an unknown node type,
    and `NotImplementedError` for a node type listed in `_unsupported`.
    """
    tokens = line.split()
    # Nothing to build for an empty line or a `;` comment line.
    if not tokens or tokens[0].startswith(";"):
        return
    if len(tokens) < 2:
        # An ID without a node type; report the line instead of an IndexError.
        raise ValueError(f"Malformed BTOR2 line: {line!r}")
    node_type = tokens[1]
    processor = _node_to_processor.get(node_type)
    if processor is not None:
        processor(c, tokens)
    elif node_type in _unsupported:
        raise NotImplementedError(f"Unsupported node type: {node_type}")
    else:
        raise ValueError(f"Unknown node type: {node_type}")


class Btor2Circuit:
    """
    Keeps track of the nodes in a solver from a built circuit (the combinational part).
    All internal nodes are referenced by their node ID, hence the `int` type.

    NOTE: Constructor has side effect: it adds netlist to the solver.
    """
    solver: bt.Boolector
    "The solver this circuit belongs to. Different circuits may share a solver."
    suffix: str
    "The suffix added to leaf nodes, in order to differentiate copies of circuit"
    name_to_output_id: bidict[str, int]
    "Map between output names and internal nodes (primary output, except `valid`)"
    name_to_state_id: bidict[str, int]
    "Map between state names and internal nodes (secondary input)"
    name_to_next_state_id: bidict[str, int]
    "Map between state names and internal nodes (secondary output)"
    name_to_init_state_id: bidict[str, int]
    "Map between state names and internal `const` nodes (initial value, optional)"
    valid_signal_id: "int | None"
    "The valid signal (`None` until an output named `valid` has been processed)"
    id_to_node: dict[int, bt.BoolectorNode]
    "Map between node IDs in the BTOR2 file and internal nodes (almost all)"
    id_to_sort: dict[int, bt.BoolectorSort]
    "Map between node IDs in the BTOR2 file and sorts"

    def __init__(self, solver: bt.Boolector, suffix: str, btor2_lines: list[str]):
        """
        Builds the circuit by processing `btor2_lines` in order.

        `solver` receives every created node; `suffix` is appended to the name
        of each state/input variable. Errors raised while processing a line
        (e.g. unknown or unsupported node types) propagate to the caller.
        """
        self.solver = solver
        self.suffix = suffix
        self.name_to_output_id = bidict()
        self.name_to_state_id = bidict()
        self.name_to_next_state_id = bidict()
        self.name_to_init_state_id = bidict()
        self.valid_signal_id = None
        self.id_to_node = {}
        self.id_to_sort = {}

        for line in btor2_lines:
            _process_node(self, line)

    def output_names(self):
        "Returns the names of the outputs (except `valid`)."
        return self.name_to_output_id.keys()

    def state_names(self):
        "Returns the names of the states."
        return self.name_to_state_id.keys()

    def output_by_name(self, name: str) -> bt.BoolectorNode:
        "Returns the internal output node with the given name (`KeyError` if unknown)."
        return self.id_to_node[self.name_to_output_id[name]]

    def valid_signal(self) -> bt.BoolectorNode:
        "Returns the internal valid signal (`KeyError` if the circuit has none)."
        if self.valid_signal_id is None:
            # Same exception type as the plain lookup, but with a readable
            # message instead of `KeyError: None`.
            raise KeyError("circuit has no `valid` output")
        return self.id_to_node[self.valid_signal_id]

    def curr_state_by_name(self, name: str) -> bt.BoolectorNode:
        "Returns the internal current state node with the given name (`KeyError` if unknown)."
        return self.id_to_node[self.name_to_state_id[name]]

    def next_state_by_name(self, name: str) -> bt.BoolectorNode:
        "Returns the internal next state node with the given name (`KeyError` if unknown)."
        return self.id_to_node[self.name_to_next_state_id[name]]

    def init_state_by_name_or_none(self, name: str) -> "bt.BoolectorNode | None":
        "Returns the internal initial state node with the given name, or `None`."
        init_id = self.name_to_init_state_id.get(name)
        if init_id is None:
            return None
        return self.id_to_node[init_id]

    def state_id_to_width_dict(self) -> dict[int, int]:
        "Returns a dictionary mapping state (and init state) IDs to their bit-width."
        result = {}
        # State IDs first, then init-value IDs (the latter were once missing,
        # see the original "initial state bug" note).
        for id_group in (self.name_to_state_id.values(), self.name_to_init_state_id.values()):
            for node_id in id_group:
                result[node_id] = self.id_to_node[node_id].width
        return result

    
def test():
    """
    Manual smoke test.

    Builds a circuit from a hard-coded BTOR2 file, fixes the current states
    `x` and `y` to constants, runs the solver and prints every node's ID,
    printable name and assignment.
    """
    btor = bt.Boolector()
    for option in (bt.BTOR_OPT_MODEL_GEN, bt.BTOR_OPT_INCREMENTAL):
        btor.Set_opt(option, True)
    btor2_lines = read_btor2_as_lines("../data/gcd_toy/gcd_slow2x.btor2")
    circuit = Btor2Circuit(btor, "@0", btor2_lines)
    print(circuit.__dict__)
    print()

    fixed_states = (("x", "0100000000000000"), ("y", "1000000000000000"))
    for state_name, bits in fixed_states:
        btor.Assert(circuit.curr_state_by_name(state_name) == bits)
    btor.Sat()

    from btor2nodenameprinter import Btor2NodeNamePrinter
    name_printer = Btor2NodeNamePrinter(btor2_lines)
    for node_id, node in circuit.id_to_node.items():
        print(node_id, name_printer.get_str(node_id, ""), node.assignment)

# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()