ArKanjo 0.2
A tool for finding duplicated functions in codebases
Loading...
Searching...
No Matches
tree_sitter_parser.cpp
Go to the documentation of this file.
2#include <iostream>
3#include <unordered_map>
8
9std::unordered_map<std::string, TSLanguage* (*)()> get_language_map();
10std::unordered_map<std::string, std::string> get_extension_map();
11
12bool TreeSitterParser::is_function_empty(TSNode body) {
13 if (ts_node_is_null(body)) return true;
14
15 uint32_t n = ts_node_named_child_count(body);
16
17 return n == 0;
18}
19
20std::string TreeSitterParser::detect_language(const std::string& path) {
21 static auto ext_map = get_extension_map();
22
23 std::string ext = fs::path(path).extension().string();
24
25 auto it = ext_map.find(ext);
26 if (it != ext_map.end()) {
27 return it->second;
28 }
29
30 return "";
31}
32
33std::string TreeSitterParser::get_full_signature(TSNode func_node, const std::string& source) {
34 uint32_t signature_start = ts_node_start_byte(func_node);
35
36 TSNode body = get_body(func_node);
37
38 uint32_t signature_end = ts_node_start_byte(body);
39
40 return source.substr(signature_start, signature_end - signature_start);
41}
42
43static const char* field_names[] = {
44 "name", "declarator"
45};
46
47std::string extract_name_generic(TSNode node, const std::string& source) {
48 for (const auto& field_name : field_names) {
49 TSNode name = ts_node_child_by_field_name(node, field_name, strlen(field_name));
50 if (!ts_node_is_null(name)) {
51 std::string_view type = ts_node_type(name);
52 if (type == "identifier") {
53 return FeatureExtractor::get_node_text(name, source);
54 }
55
56 if (type == "qualified_identifier") {
57 TSNode qualified_identifier_name = ts_node_child_by_field_name(name, "name", strlen("name"));
58 if (!ts_node_is_null(qualified_identifier_name)) {
59 return source.substr(
60 ts_node_start_byte(qualified_identifier_name),
61 ts_node_end_byte(qualified_identifier_name) - ts_node_start_byte(qualified_identifier_name)
62 );
63 }
64 }
65 }
66 }
67
68 uint32_t count = ts_node_child_count(node);
69 for (uint32_t i = 0; i < count; i++) {
70 TSNode child = ts_node_child(node, i);
71 std::string result = extract_name_generic(child, source);
72 if (!result.empty()) return result;
73 }
74
75 return "";
76}
77
78std::string TreeSitterParser::get_function_name(TSNode func_node, const std::string& source) {
79 return extract_name_generic(func_node, source);
80}
81
82TSNode TreeSitterParser::get_body(TSNode node) {
83 TSNode body = ts_node_child_by_field_name(node, "body", strlen("body"));
84
85 if (!ts_node_is_null(body))
86 return body;
87
88 uint32_t count = ts_node_child_count(node);
89
90 for (uint32_t i = 0; i < count; i++) {
91 TSNode child = ts_node_child(node, i);
92 std::string_view type = ts_node_type(child);
94 return child;
95 }
96 }
97
98 return TSNode{};
99}
100
101void TreeSitterParser::collect_functions(
102 TSNode node,
103 const std::string& source,
104 const fs::path& relative_path,
105 const std::shared_ptr<TSTree>& tree,
106 std::function<void(const FunctionData&)> callback
107) {
108 std::string_view type = ts_node_type(node);
110 TSPoint start = ts_node_start_point(node);
111 TSPoint end = ts_node_end_point(node);
112
113 TSNode body = get_body(node);
114
115 if (is_function_empty(body)) return;
116
117 std::string function_name = get_function_name(node, source);
118 if (function_name.empty()) return;
119
120 TSPoint body_start = ts_node_start_point(body);
121 uint32_t start_byte = ts_node_start_byte(node);
122 uint32_t end_byte = ts_node_end_byte(node);
123
124 std::string signature = get_full_signature(node, source);
125 std::string code = source.substr(start_byte + signature.size(), end_byte - (start_byte + signature.size()));
126
127 FunctionData function;
128 function.path = relative_path.string();
129 function.function_name = function_name;
130
131 auto source = std::make_shared<SourceFeature>();
132 source->code = code;
133 function.add_feature(source);
134
135 auto ast = std::make_shared<ASTFeature>();
136 ast->tree = tree;
137 ast->root = body;
138 function.add_feature(ast);
139
140 auto metadata = std::make_shared<MetadataFeature>();
141 metadata->signature = signature;
142 metadata->line_declaration = start.row;
143 metadata->start_number_line = body_start.row;
144 metadata->end_number_line = end.row;
145 function.add_feature(metadata);
146
147 callback(function);
148 }
149
150 uint32_t count = ts_node_child_count(node);
151 for (uint32_t i = 0; i < count; i++) {
152 collect_functions(
153 ts_node_child(node, i),
154 source,
155 relative_path,
156 tree,
157 callback
158 );
159 }
160}
161
163 const fs::path& file_path,
164 const fs::path& relative_path,
165 const std::string& source_code,
166 std::function<void(const FunctionData&)> callback
167) {
168 auto lang_name = detect_language(file_path);
169
170 if (lang_name.empty()) return;
171
172 static auto language_map = get_language_map();
173 auto it = language_map.find(lang_name);
174 if (it == language_map.end()) {
175 std::cerr << "Language not found: " << lang_name << "\n";
176 return;
177 }
178 TSLanguage* language = it->second();
179
180 static TSLanguage* current_language = nullptr;
181 static std::unique_ptr<TSParser, decltype(&ts_parser_delete)> parser(
182 ts_parser_new(),
183 ts_parser_delete
184 );
185 if (current_language != language) {
186 if (!ts_parser_set_language(parser.get(), language)) {
187 return;
188 }
189 current_language = language;
190 }
191
192 std::shared_ptr<TSTree> tree(
193 ts_parser_parse_string(
194 parser.get(),
195 nullptr,
196 source_code.c_str(),
197 source_code.size()
198 ),
199 ts_tree_delete
200 );
201 TSNode root_node = ts_tree_root_node(tree.get());
202
203 collect_functions(root_node, source_code, relative_path, tree, callback);
204
205 tree.reset();
206}
static std::string get_node_text(TSNode node, const std::string &source)
static bool is_block_node(std::string_view type)
static bool is_function_node(std::string_view type)
std::string function_name
Name of the function.
std::string path
void add_feature(std::shared_ptr< T > feature)
static void process_file(const fs::path &file_path, const fs::path &relative_path, const std::string &source_code, std::function< void(const FunctionData &)> callback)
std::unordered_map< std::string, std::string > get_extension_map()
std::string extract_name_generic(TSNode node, const std::string &source)
std::unordered_map< std::string, TSLanguage *(*)()> get_language_map()
Defines utility functions used across all files.