Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 206 additions & 21 deletions src/ida_pro_mcp/mcp-plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -249,6 +250,10 @@ def _run_server(self):
import ida_xref
import ida_entry
import idautils
import ida_idd
import ida_dbg
import ida_name
import ida_ida

class IDAError(Exception):
def __init__(self, message: str):
Expand Down Expand Up @@ -508,7 +513,7 @@ def get_function_by_name(
@jsonrpc
@idaread
def get_function_by_address(
address: Annotated[str, "Address of the function to get"]
address: Annotated[str, "Address of the function to get"],
) -> Function:
"""Get a function by its address"""
return get_function(parse_address(address))
Expand Down Expand Up @@ -573,7 +578,7 @@ def convert_number(
"hexadecimal": hex(value),
"bytes": bytes.hex(" "),
"ascii": ascii,
"binary": bin(value)
"binary": bin(value),
}

T = TypeVar("T")
Expand All @@ -589,7 +594,7 @@ def paginate(data: list[T], offset: int, count: int) -> Page[T]:
if next_offset >= len(data):
next_offset = None
return {
"data": data[offset:offset+count],
"data": data[offset:offset + count],
"next_offset": next_offset,
}

Expand Down Expand Up @@ -620,7 +625,7 @@ def get_strings() -> list[String]:
"address": hex(item.ea),
"length": item.length,
"type": string_type,
"string": string
"string": string,
})
except:
continue
Expand All @@ -639,9 +644,9 @@ def list_strings(
@jsonrpc
@idaread
def search_strings(
pattern_str: Annotated[str, "The regular expression to match((The generated regular expression includes case by default))"],
offset: Annotated[int, "Offset to start listing from (start at 0)"],
count: Annotated[int, "Number of strings to list (100 is a good default, 0 means remainder)"],
pattern_str: Annotated[str, "The regular expression to match((The generated regular expression includes case by default))"],
offset: Annotated[int, "Offset to start listing from (start at 0)"],
count: Annotated[int, "Number of strings to list (100 is a good default, 0 means remainder)"],
) -> Page[String]:
"""Search for strings that satisfy a regular expression"""
strings = get_strings()
Expand All @@ -667,8 +672,6 @@ def search_strings(
matched_strings = [s for s in strings if pattern.lower() in s["string"].lower()]
return paginate(matched_strings, offset, count)



def decompile_checked(address: int) -> ida_hexrays.cfunc_t:
if not ida_hexrays.init_hexrays_plugin():
raise IDAError("Hex-Rays decompiler is not available")
Expand All @@ -686,7 +689,7 @@ def decompile_checked(address: int) -> ida_hexrays.cfunc_t:
@jsonrpc
@idaread
def decompile_function(
address: Annotated[str, "Address of the function to decompile"]
address: Annotated[str, "Address of the function to decompile"],
) -> str:
"""Decompile a function at the given address"""
address = parse_address(address)
Expand Down Expand Up @@ -719,7 +722,7 @@ def decompile_function(
@jsonrpc
@idaread
def disassemble_function(
start_address: Annotated[str, "Address of the function to disassemble"]
start_address: Annotated[str, "Address of the function to disassemble"],
) -> str:
"""Get assembly code (address: instruction; comment) for a function"""
start = parse_address(start_address)
Expand Down Expand Up @@ -751,7 +754,7 @@ class Xref(TypedDict):
@jsonrpc
@idaread
def get_xrefs_to(
address: Annotated[str, "Address to get cross references to"]
address: Annotated[str, "Address to get cross references to"],
) -> list[Xref]:
"""Get all cross references to the given address"""
xrefs = []
Expand Down Expand Up @@ -781,7 +784,7 @@ def get_entry_points() -> list[Function]:
@idawrite
def set_comment(
address: Annotated[str, "Address in the function to set the comment for"],
comment: Annotated[str, "Comment text"]
comment: Annotated[str, "Comment text"],
):
"""Set a comment for a given address in the function disassembly and pseudocode"""
address = parse_address(address)
Expand Down Expand Up @@ -842,7 +845,7 @@ def refresh_decompiler_ctext(function_address: int):
def rename_local_variable(
function_address: Annotated[str, "Address of the function containing the variable"],
old_name: Annotated[str, "Current name of the variable"],
new_name: Annotated[str, "New name for the variable (empty for a default name)"]
new_name: Annotated[str, "New name for the variable (empty for a default name)"],
):
"""Rename a local variable in a function"""
func = idaapi.get_func(parse_address(function_address))
Expand All @@ -856,7 +859,7 @@ def rename_local_variable(
@idawrite
def rename_global_variable(
old_name: Annotated[str, "Current name of the global variable"],
new_name: Annotated[str, "New name for the global variable (empty for a default name)"]
new_name: Annotated[str, "New name for the global variable (empty for a default name)"],
):
"""Rename a global variable"""
ea = idaapi.get_name_ea(idaapi.BADADDR, old_name)
Expand All @@ -868,7 +871,7 @@ def rename_global_variable(
@idawrite
def set_global_variable_type(
variable_name: Annotated[str, "Name of the global variable"],
new_type: Annotated[str, "New type for the variable"]
new_type: Annotated[str, "New type for the variable"],
):
"""Set a global variable's type"""
ea = idaapi.get_name_ea(idaapi.BADADDR, variable_name)
Expand All @@ -882,7 +885,7 @@ def set_global_variable_type(
@idawrite
def rename_function(
function_address: Annotated[str, "Address of the function to rename"],
new_name: Annotated[str, "New name for the function (empty for a default name)"]
new_name: Annotated[str, "New name for the function (empty for a default name)"],
):
"""Rename a function"""
func = idaapi.get_func(parse_address(function_address))
Expand All @@ -896,7 +899,7 @@ def rename_function(
@idawrite
def set_function_prototype(
function_address: Annotated[str, "Address of the function"],
prototype: Annotated[str, "New function prototype"]
prototype: Annotated[str, "New function prototype"],
) -> str:
"""Set a function's prototype"""
func = idaapi.get_func(parse_address(function_address))
Expand Down Expand Up @@ -930,15 +933,22 @@ def modify_lvars(self, lvars):
def parse_decls_ctypes(decls: str, hti_flags: int) -> tuple[int, str]:
if sys.platform == "win32":
import ctypes

assert isinstance(decls, str), "decls must be a string"
assert isinstance(hti_flags, int), "hti_flags must be an int"
c_decls = decls.encode("utf-8")
c_til = None
ida_dll = ctypes.CDLL("ida")
ida_dll.parse_decls.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_int]
ida_dll.parse_decls.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.c_void_p,
ctypes.c_int,
]
ida_dll.parse_decls.restype = ctypes.c_int

messages = []

@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p)
def magic_printer(fmt: bytes, arg1: bytes):
if fmt.count(b"%") == 1 and b"%s" in fmt:
Expand Down Expand Up @@ -979,7 +989,7 @@ def declare_c_type(
def set_local_variable_type(
function_address: Annotated[str, "Address of the function containing the variable"],
variable_name: Annotated[str, "Name of the variable"],
new_type: Annotated[str, "New type for the variable"]
new_type: Annotated[str, "New type for the variable"],
):
"""Set a local variable's type"""
try:
Expand All @@ -989,7 +999,7 @@ def set_local_variable_type(
try:
new_tif = ida_typeinf.tinfo_t()
# parse_decl requires semicolon for the type
ida_typeinf.parse_decl(new_tif, None, new_type+";", ida_typeinf.PT_SIL)
ida_typeinf.parse_decl(new_tif, None, new_type + ";", ida_typeinf.PT_SIL)
except Exception:
raise IDAError(f"Failed to parse type: {new_type}")
func = idaapi.get_func(parse_address(function_address))
Expand All @@ -1002,6 +1012,181 @@ def set_local_variable_type(
raise IDAError(f"Failed to modify local variable: {variable_name}")
refresh_decompiler_ctext(func.start_ea)

@jsonrpc
@idaread
def dbg_get_registers() -> list[dict[str, str]]:
"""Get all registers and their values. This function is only available when debugging."""
result = []
dbg = ida_idd.get_dbg()
# TODO: raise an exception when not debugging?
for thread_index in range(ida_dbg.get_thread_qty()):
tid = ida_dbg.getn_thread(thread_index)
regs = []
regvals = ida_dbg.get_reg_vals(tid)
for reg_index, rv in enumerate(regvals):
reg_info = dbg.regs(reg_index)
reg_value = rv.pyval(reg_info.dtype)
if isinstance(reg_value, int):
reg_value = hex(reg_value)
if isinstance(reg_value, bytes):
reg_value = reg_value.hex(" ")
regs.append({
"name": reg_info.name,
"value": reg_value,
})
result.append({
"thread_id": tid,
"registers": regs,
})
return result

@jsonrpc
@idaread
def dbg_get_call_stack() -> list[dict[str, str]]:
"""Get the current call stack."""
callstack = []
try:
tid = ida_dbg.get_current_thread()
trace = ida_idd.call_stack_t()

if not ida_dbg.collect_stack_trace(tid, trace):
return []
for frame in trace:
frame_info = {
"address": hex(frame.callea),
}
try:
module_info = ida_idd.modinfo_t()
if ida_dbg.get_module_info(frame.callea, module_info):
frame_info["module"] = os.path.basename(module_info.name)
else:
frame_info["module"] = "<unknown>"

name = (
ida_name.get_nice_colored_name(
frame.callea,
ida_name.GNCN_NOCOLOR
| ida_name.GNCN_NOLABEL
| ida_name.GNCN_NOSEG
| ida_name.GNCN_PREFDBG,
)
or "<unnamed>"
)
frame_info["symbol"] = name

except Exception as e:
frame_info["module"] = "<error>"
frame_info["symbol"] = str(e)

callstack.append(frame_info)

except Exception as e:
pass
return callstack

@jsonrpc
@idaread
def dbg_list_breakpoints():
"""
List all breakpoints in the program.
"""
ea = ida_ida.inf_get_min_ea()
end_ea = ida_ida.inf_get_max_ea()
bkpts = []
while ea <= end_ea:
bpt = ida_dbg.bpt_t()
if ida_dbg.get_bpt(ea, bpt):
bkpts.append(
{
"ea": hex(bpt.ea),
"type": bpt.type,
"enabled": bpt.flags & ida_dbg.BPT_ENABLED,
"condition": bpt.condition if bpt.condition else None,
}
)
ea = ida_bytes.next_head(ea, end_ea)
return bkpts

@jsonrpc
@idaread
def dbg_start_process() -> str:
"""Start the debugger"""
ret = idaapi.start_process("", "", "")
if ret == 1:
return "Debugger started"
return "Failed to start debugger"

@jsonrpc
@idaread
def dbg_exit_process() -> str:
"""Exit the debugger"""
ret = idaapi.exit_process()
if ret == 1:
return "Debugger exited"
return "Failed to exit debugger"

@jsonrpc
@idaread
def dbg_continue_process() -> str:
"""Continue the debugger"""
ret = idaapi.continue_process()
if ret == 1:
return "Debugger continued"
return "Failed to continue debugger"

@jsonrpc
@idaread
def dbg_run_to(
address: Annotated[str, "Run the debugger to the specified address"],
) -> str:
"""Run the debugger to the specified address"""
ea = parse_address(address)
ret = idaapi.run_to(ea)
if ret == 1:
return f"Debugger run to {hex(ea)}"
return f"Failed to run to address {hex(ea)}"

@jsonrpc
@idaread
def dbg_set_breakpoint(
address: Annotated[str, "Set a breakpoint at the specified address"],
) -> str:
"""Set a breakpoint at the specified address"""
ea = parse_address(address)
ret = idaapi.add_bpt(ea, 0, idaapi.BPT_SOFT)
if ret == 1:
return f"Breakpoint set at {hex(ea)}"
bpts = dbg_list_breakpoints()
for bpt in bpts:
if bpt["ea"] == hex(ea):
return f"Breakpoint already exists at {hex(ea)}"
return f"Failed to set breakpoint at address {hex(ea)}"

@jsonrpc
@idaread
def dbg_delete_breakpoint(
address: Annotated[str, "del a breakpoint at the specified address"],
) -> str:
"""del a breakpoint at the specified address"""
ea = parse_address(address)
ret = idaapi.del_bpt(ea)
if ret == 1:
return f"Breakpoint deleted at {hex(ea)}"
return f"Failed to delete breakpoint at address {hex(ea)}"

@jsonrpc
@idaread
def dbg_enable_breakpoint(
address: Annotated[str, "Enable or disable a breakpoint at the specified address"],
enable: Annotated[bool, "Enable or disable a breakpoint"],
) -> str:
"""Enable or disable a breakpoint at the specified address"""
ea = parse_address(address)
ret = idaapi.enable_bpt(ea, enable)
if ret == 1:
return f"Breakpoint {'enabled' if enable else 'disabled'} at {hex(ea)}"
return f"Failed to {'' if enable else 'disable '}breakpoint at address {hex(ea)}"

class MCP(idaapi.plugin_t):
flags = idaapi.PLUGIN_KEEP
comment = "MCP Plugin"
Expand Down