Module tourniquet.tourniquet
Expand source code
import shutil
import subprocess
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, Optional
from . import extractor, models
from .error import PatchSituationError, TemplateNameError
from .location import Location, SourceCoordinate
from .patch_lang import PatchTemplate
class Tourniquet:
def __init__(self, database_name):
self.db_name = database_name
self.db = models.DB.create(database_name)
self.patch_templates: Dict[str, PatchTemplate] = {}
def _path_looks_like_cxx(self, source_path: Path):
return source_path.suffix in [".cpp", ".cc", ".cxx"]
def _extract_ast(self, source_path: Path, is_cxx: bool = True) -> Dict[str, Any]:
if not source_path.is_file():
raise FileNotFoundError(f"{source_path} is not a file")
return extractor.extract_ast(source_path, is_cxx)
def _store_ast(self, ast_info: Dict[str, Any]):
module = models.Module(name=ast_info["module_name"])
for global_ in ast_info["globals"]:
assert global_[0] == "var_type", f"{global_[0]} != var_type"
global_ = models.Global(
module=module,
name=global_[5],
type_=global_[6],
start_line=global_[1],
start_column=global_[2],
end_line=global_[3],
end_column=global_[4],
is_array=bool(global_[7]),
size=global_[8],
)
self.db.session.add(global_)
# Every subsequent member
for func_name, exprs in ast_info["functions"].items():
# NOTE(ww): We expected the first member of each function's list to be
# a "func_decl" list, containing information about the function declaration
# itself. We use this to construct the initial Function model.
# If a list doesn't begin with "func_decl," then it was external and we
# skip it.
# TODO(ww): Think more about the above.
exprs = iter(exprs)
func_decl = next(exprs)
if func_decl[0] != "func_decl":
continue
function = models.Function(
module=module,
name=func_name,
start_line=func_decl[1],
start_column=func_decl[2],
end_line=func_decl[3],
end_column=func_decl[4],
)
self.db.session.add(function)
for expr in exprs:
# From here, the exprs we know are "var_type" (models.VarDecl),
# "call_type" (models.Call), and "stmt_type" (models.Statement).
# "call_type" lists contain, in turn, a list of arguments,
# which we promote to models.Argument objects.
if expr[0] == "var_type":
var_decl = models.VarDecl(
function=function,
name=expr[5],
type_=expr[6],
start_line=expr[1],
start_column=expr[2],
end_line=expr[3],
end_column=expr[4],
is_array=bool(expr[7]),
size=expr[8],
)
self.db.session.add(var_decl)
elif expr[0] == "call_type":
call = models.Call(
module=module,
function=function,
expr=expr[5],
name=expr[6],
start_line=expr[1],
start_column=expr[2],
end_line=expr[3],
end_column=expr[4],
)
self.db.session.add(call)
for name, type_ in expr[7:]:
argument = models.Argument(call=call, name=name, type_=type_)
self.db.session.add(argument)
elif expr[0] == "stmt_type":
stmt = models.Statement(
module=module,
function=function,
expr=expr[5],
start_line=expr[1],
start_column=expr[2],
end_line=expr[3],
end_column=expr[4],
)
self.db.session.add(stmt)
else:
assert False, expr[0]
self.db.session.commit()
# TODO Should take a target
def collect_info(self, source_path: Path):
"""
Collect information about the given source file and add it to the backing database.
"""
ast_info = self._extract_ast(source_path, is_cxx=self._path_looks_like_cxx(source_path))
self._store_ast(ast_info)
def register_template(self, name: str, template: PatchTemplate):
"""
Register a patching template with the given name.
"""
if name in self.patch_templates:
raise TemplateNameError(f"a template has already been registered as {name}")
self.patch_templates[name] = template
# TODO Should take target
# TODO(ww): Consider rehoming this?
def view_template(self, template_name, location: Location) -> Optional[str]:
"""
Pretty-print the given template, partially concretized to the given
module and source location.
"""
template = self.patch_templates.get(template_name)
if template is None:
return None
print("=" * 10, template_name, "=" * 10)
view_str = template.view(self.db, location)
print(view_str)
print("=" * 10, "END", "=" * 10)
return view_str
# TODO Should take a target
def concretize_template(self, template_name: str, location: Location) -> Iterator[str]:
"""
Concretize the given registered template to the given
module and source location, yielding each candidate patch.
Args:
template_name: The name of the template to concretize. This name
must have been previously registered with `register_template`.
location: The `Location` to concretize the template at.
Returns:
A generator of strings, each one representing a concrete patch suitable
for placement at the supplied location.
Raises:
TemplateNameError: If the supplied template name isn't registered.
"""
template = self.patch_templates.get(template_name)
if template is None:
raise TemplateNameError(f"no template registed with name {template_name}")
yield from template.concretize(self.db, location)
# TODO Should take a target
# TODO(ww): This should take a span instead of a location, so that it doesn't have
# to depend on the patch location being a statement.
@contextmanager
def patch(self, replacement: str, location: Location) -> Iterator[bool]:
"""
Applies the given replacement to the given location.
Rolls back the replacement after context closure.
Args:
replacement: The patch to insert.
location: The `Location` to insert at, including the source file.
Returns:
A generator whose single yield is the status of the replacement operation.
Raises:
PatchSituationError: If the supplied location can't be used for a patch.
"""
try:
temp_file = tempfile.NamedTemporaryFile()
shutil.copyfile(location.filename, temp_file.name)
statement = self.db.statement_at(location)
if statement is None:
raise PatchSituationError(f"no statement at ({location.line}, {location.column})")
yield self.transform(
location.filename,
replacement,
statement.start_coordinate,
statement.end_coordinate,
)
finally:
shutil.copyfile(temp_file.name, location.filename)
temp_file.close()
# TODO Should take a target
def auto_patch(self, template_name, tests, location: Location) -> Optional[str]:
# TODO(ww): This should be a NamedTempFile, at the absolute minimum.
EXEC_FILE = Path("/tmp/target")
# Collect replacements
replacements = self.concretize_template(template_name, location)
# Patch
for replacement in replacements:
with self.patch(replacement, location):
# Just compile with clang for now
ret = subprocess.call(["clang", "-g", "-o", EXEC_FILE, location.filename])
if ret != 0:
print("Error, build failed?")
continue
# Run the test suite
failed_test = False
for test_case in tests:
(input_, output) = test_case
ret = subprocess.call([EXEC_FILE, input_])
if output != ret:
failed_test = True
break
if not failed_test:
# This means that its fixed :)
return replacement
return None
def transform(
self, filename: Path, replacement: str, start: SourceCoordinate, end: SourceCoordinate
):
res = extractor.transform(
filename,
self._path_looks_like_cxx(filename),
replacement,
start.line,
start.column,
end.line,
end.column,
)
return res
Classes
class Tourniquet (database_name)
-
Expand source code
class Tourniquet: def __init__(self, database_name): self.db_name = database_name self.db = models.DB.create(database_name) self.patch_templates: Dict[str, PatchTemplate] = {} def _path_looks_like_cxx(self, source_path: Path): return source_path.suffix in [".cpp", ".cc", ".cxx"] def _extract_ast(self, source_path: Path, is_cxx: bool = True) -> Dict[str, Any]: if not source_path.is_file(): raise FileNotFoundError(f"{source_path} is not a file") return extractor.extract_ast(source_path, is_cxx) def _store_ast(self, ast_info: Dict[str, Any]): module = models.Module(name=ast_info["module_name"]) for global_ in ast_info["globals"]: assert global_[0] == "var_type", f"{global_[0]} != var_type" global_ = models.Global( module=module, name=global_[5], type_=global_[6], start_line=global_[1], start_column=global_[2], end_line=global_[3], end_column=global_[4], is_array=bool(global_[7]), size=global_[8], ) self.db.session.add(global_) # Every subsequent member for func_name, exprs in ast_info["functions"].items(): # NOTE(ww): We expected the first member of each function's list to be # a "func_decl" list, containing information about the function declaration # itself. We use this to construct the initial Function model. # If a list doesn't begin with "func_decl," then it was external and we # skip it. # TODO(ww): Think more about the above. exprs = iter(exprs) func_decl = next(exprs) if func_decl[0] != "func_decl": continue function = models.Function( module=module, name=func_name, start_line=func_decl[1], start_column=func_decl[2], end_line=func_decl[3], end_column=func_decl[4], ) self.db.session.add(function) for expr in exprs: # From here, the exprs we know are "var_type" (models.VarDecl), # "call_type" (models.Call), and "stmt_type" (models.Statement). # "call_type" lists contain, in turn, a list of arguments, # which we promote to models.Argument objects. if expr[0] == "var_type": var_decl = models.VarDecl( function=function, name=expr[5], type_=expr[6], start_line=expr[1], start_column=expr[2], end_line=expr[3], end_column=expr[4], is_array=bool(expr[7]), size=expr[8], ) self.db.session.add(var_decl) elif expr[0] == "call_type": call = models.Call( module=module, function=function, expr=expr[5], name=expr[6], start_line=expr[1], start_column=expr[2], end_line=expr[3], end_column=expr[4], ) self.db.session.add(call) for name, type_ in expr[7:]: argument = models.Argument(call=call, name=name, type_=type_) self.db.session.add(argument) elif expr[0] == "stmt_type": stmt = models.Statement( module=module, function=function, expr=expr[5], start_line=expr[1], start_column=expr[2], end_line=expr[3], end_column=expr[4], ) self.db.session.add(stmt) else: assert False, expr[0] self.db.session.commit() # TODO Should take a target def collect_info(self, source_path: Path): """ Collect information about the given source file and add it to the backing database. """ ast_info = self._extract_ast(source_path, is_cxx=self._path_looks_like_cxx(source_path)) self._store_ast(ast_info) def register_template(self, name: str, template: PatchTemplate): """ Register a patching template with the given name. """ if name in self.patch_templates: raise TemplateNameError(f"a template has already been registered as {name}") self.patch_templates[name] = template # TODO Should take target # TODO(ww): Consider rehoming this? def view_template(self, template_name, location: Location) -> Optional[str]: """ Pretty-print the given template, partially concretized to the given module and source location. """ template = self.patch_templates.get(template_name) if template is None: return None print("=" * 10, template_name, "=" * 10) view_str = template.view(self.db, location) print(view_str) print("=" * 10, "END", "=" * 10) return view_str # TODO Should take a target def concretize_template(self, template_name: str, location: Location) -> Iterator[str]: """ Concretize the given registered template to the given module and source location, yielding each candidate patch. Args: template_name: The name of the template to concretize. This name must have been previously registered with `register_template`. location: The `Location` to concretize the template at. Returns: A generator of strings, each one representing a concrete patch suitable for placement at the supplied location. Raises: TemplateNameError: If the supplied template name isn't registered. """ template = self.patch_templates.get(template_name) if template is None: raise TemplateNameError(f"no template registed with name {template_name}") yield from template.concretize(self.db, location) # TODO Should take a target # TODO(ww): This should take a span instead of a location, so that it doesn't have # to depend on the patch location being a statement. @contextmanager def patch(self, replacement: str, location: Location) -> Iterator[bool]: """ Applies the given replacement to the given location. Rolls back the replacement after context closure. Args: replacement: The patch to insert. location: The `Location` to insert at, including the source file. Returns: A generator whose single yield is the status of the replacement operation. Raises: PatchSituationError: If the supplied location can't be used for a patch. """ try: temp_file = tempfile.NamedTemporaryFile() shutil.copyfile(location.filename, temp_file.name) statement = self.db.statement_at(location) if statement is None: raise PatchSituationError(f"no statement at ({location.line}, {location.column})") yield self.transform( location.filename, replacement, statement.start_coordinate, statement.end_coordinate, ) finally: shutil.copyfile(temp_file.name, location.filename) temp_file.close() # TODO Should take a target def auto_patch(self, template_name, tests, location: Location) -> Optional[str]: # TODO(ww): This should be a NamedTempFile, at the absolute minimum. EXEC_FILE = Path("/tmp/target") # Collect replacements replacements = self.concretize_template(template_name, location) # Patch for replacement in replacements: with self.patch(replacement, location): # Just compile with clang for now ret = subprocess.call(["clang", "-g", "-o", EXEC_FILE, location.filename]) if ret != 0: print("Error, build failed?") continue # Run the test suite failed_test = False for test_case in tests: (input_, output) = test_case ret = subprocess.call([EXEC_FILE, input_]) if output != ret: failed_test = True break if not failed_test: # This means that its fixed :) return replacement return None def transform( self, filename: Path, replacement: str, start: SourceCoordinate, end: SourceCoordinate ): res = extractor.transform( filename, self._path_looks_like_cxx(filename), replacement, start.line, start.column, end.line, end.column, ) return res
Methods
def auto_patch(self, template_name, tests, location: Location) ‑> Optional[str]
-
Expand source code
def auto_patch(self, template_name, tests, location: Location) -> Optional[str]: # TODO(ww): This should be a NamedTempFile, at the absolute minimum. EXEC_FILE = Path("/tmp/target") # Collect replacements replacements = self.concretize_template(template_name, location) # Patch for replacement in replacements: with self.patch(replacement, location): # Just compile with clang for now ret = subprocess.call(["clang", "-g", "-o", EXEC_FILE, location.filename]) if ret != 0: print("Error, build failed?") continue # Run the test suite failed_test = False for test_case in tests: (input_, output) = test_case ret = subprocess.call([EXEC_FILE, input_]) if output != ret: failed_test = True break if not failed_test: # This means that its fixed :) return replacement return None
def collect_info(self, source_path: pathlib.Path)
-
Collect information about the given source file and add it to the backing database.
Expand source code
def collect_info(self, source_path: Path): """ Collect information about the given source file and add it to the backing database. """ ast_info = self._extract_ast(source_path, is_cxx=self._path_looks_like_cxx(source_path)) self._store_ast(ast_info)
def concretize_template(self, template_name: str, location: Location) ‑> Iterator[str]
-
Concretize the given registered template to the given module and source location, yielding each candidate patch.
Args
template_name
- The name of the template to concretize. This name
must have been previously registered with
register_template
. location
- The
Location
to concretize the template at.
Returns
A generator of strings, each one representing a concrete patch suitable for placement at the supplied location.
Raises
TemplateNameError
- If the supplied template name isn't registered.
Expand source code
def concretize_template(self, template_name: str, location: Location) -> Iterator[str]: """ Concretize the given registered template to the given module and source location, yielding each candidate patch. Args: template_name: The name of the template to concretize. This name must have been previously registered with `register_template`. location: The `Location` to concretize the template at. Returns: A generator of strings, each one representing a concrete patch suitable for placement at the supplied location. Raises: TemplateNameError: If the supplied template name isn't registered. """ template = self.patch_templates.get(template_name) if template is None: raise TemplateNameError(f"no template registed with name {template_name}") yield from template.concretize(self.db, location)
def patch(self, replacement: str, location: Location) ‑> Iterator[bool]
-
Applies the given replacement to the given location.
Rolls back the replacement after context closure.
Args
replacement
- The patch to insert.
location
- The
Location
to insert at, including the source file.
Returns
A generator whose single yield is the status of the replacement operation.
Raises
PatchSituationError
- If the supplied location can't be used for a patch.
Expand source code
@contextmanager def patch(self, replacement: str, location: Location) -> Iterator[bool]: """ Applies the given replacement to the given location. Rolls back the replacement after context closure. Args: replacement: The patch to insert. location: The `Location` to insert at, including the source file. Returns: A generator whose single yield is the status of the replacement operation. Raises: PatchSituationError: If the supplied location can't be used for a patch. """ try: temp_file = tempfile.NamedTemporaryFile() shutil.copyfile(location.filename, temp_file.name) statement = self.db.statement_at(location) if statement is None: raise PatchSituationError(f"no statement at ({location.line}, {location.column})") yield self.transform( location.filename, replacement, statement.start_coordinate, statement.end_coordinate, ) finally: shutil.copyfile(temp_file.name, location.filename) temp_file.close()
def register_template(self, name: str, template: PatchTemplate)
-
Register a patching template with the given name.
Expand source code
def register_template(self, name: str, template: PatchTemplate): """ Register a patching template with the given name. """ if name in self.patch_templates: raise TemplateNameError(f"a template has already been registered as {name}") self.patch_templates[name] = template
def transform(self, filename: pathlib.Path, replacement: str, start: SourceCoordinate, end: SourceCoordinate)
-
Expand source code
def transform( self, filename: Path, replacement: str, start: SourceCoordinate, end: SourceCoordinate ): res = extractor.transform( filename, self._path_looks_like_cxx(filename), replacement, start.line, start.column, end.line, end.column, ) return res
def view_template(self, template_name, location: Location) ‑> Optional[str]
-
Pretty-print the given template, partially concretized to the given module and source location.
Expand source code
def view_template(self, template_name, location: Location) -> Optional[str]: """ Pretty-print the given template, partially concretized to the given module and source location. """ template = self.patch_templates.get(template_name) if template is None: return None print("=" * 10, template_name, "=" * 10) view_str = template.view(self.db, location) print(view_str) print("=" * 10, "END", "=" * 10) return view_str