| """ |
| Original work: |
| https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py |
| |
| Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) |
| |
| All credits to the original authors. |
| """ |
| import re |
| import torch |
| from transformers import Pipeline |
|
|
|
|
class FunctionPreprocessor:
    """Remove functions that are defined but never referenced.

    A function counts as "unused" when its name occurs at most once in the
    code (that single occurrence being its own ``def`` line).
    """

    def get_function(self, code):
        """Return the names of all column-0 function definitions in *code*.

        Only definitions preceded by a newline are matched, so *code* is
        expected to start with ``"\\n"`` (``preprocess`` guarantees this).
        """
        # r"" raw string: avoids the invalid "\(" escape in a plain literal.
        return [m[4:-1].strip() for m in re.findall(r"\ndef [a-zA-Z0-9_]+\(", code)]

    def determine_function(self, code, function_name):
        """Return True when *function_name* is referenced beyond its definition.

        One match is the ``def`` line itself; more than one means the
        function is actually used somewhere.
        """
        # re.escape is a no-op for [a-zA-Z0-9_] names but keeps the pattern
        # safe should an unexpected name ever reach this method.
        pattern = "[^a-zA-Z]" + re.escape(function_name) + "[^a-zA-Z]"
        return len(re.findall(pattern, code)) > 1

    def delete_function(self, code, name):
        """Delete the definition of function *name* from *code*.

        Scans from the ``def`` line to the next line that begins with an
        alphabetic character (the next column-0 statement) and removes
        everything in between. If that boundary is never found (the
        function runs to the end of the code), *code* is returned intact.
        """
        match = re.search("def " + re.escape(name), code)
        if match is None:
            # Defensive guard: nothing to delete (original raised here).
            return code
        start_id = match.start()
        ptr = start_id

        # Advance to the newline that precedes the next top-level statement.
        while ptr < len(code) - 1:
            if code[ptr] == "\n" and re.search("[a-zA-Z]", code[ptr + 1]) is not None:
                break
            ptr += 1

        if ptr != len(code) - 1:
            code = code[:start_id] + code[ptr:]

        return code

    def preprocess(self, code):
        """Return *code* with every unused function definition removed.

        A leading newline is prepended so a definition on the very first
        line is still matched by the ``"\\ndef"`` pattern.
        """
        code = "\n" + code
        for fn in self.get_function(code):
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)
        return code
|
|
|
|
class AnnotationPreprocessor:
    """Strip docstrings, ``#`` comments and import lines from source code,
    then collapse all whitespace runs into single spaces."""

    def search(self, sen_list, string):
        """Return the index of the first line containing *string*, or -1."""
        for i, sen in enumerate(sen_list):
            if string in sen:
                return i
        return -1

    def delete_annotation_block(self, code, string):
        """Delete one block delimited by *string* (e.g. a triple-quoted
        docstring), inclusive of both delimiter lines.

        When no closing delimiter is found after the opening line, only
        the opening line is removed (handles stray/one-line delimiters).
        """
        sens = code.split("\n")

        start_id = self.search(sens, string)
        end_id = self.search(sens[start_id + 1 :], string)
        if end_id != -1:
            end_id += start_id + 1
            sens = sens[:start_id] + sens[end_id + 1 :]
        else:
            sens = sens[:start_id] + sens[start_id + 1 :]

        return "\n".join(sens)

    def delete_block(self, code, string):
        """Repeatedly delete delimited blocks until *string* is gone.

        Each pass removes at least the opening line, so this terminates.
        """
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        """Truncate every line at its first ``#`` character.

        NOTE: a ``#`` inside a string literal is truncated too — an
        accepted trade-off of this line-based heuristic.
        """
        return "\n".join(line.split("#", 1)[0] for line in code.split("\n"))

    def delete_import(self, code):
        """Drop every line containing the substring ``import``."""
        return "\n".join(line for line in code.split("\n") if "import" not in line)

    def preprocess(self, code):
        """Apply all removals, then normalise whitespace to single spaces."""
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        # r"" raw string avoids the invalid-escape warning of "\s+".
        return re.sub(r"\s+", " ", code).strip()
|
|
|
|
def preprocessor(code, instance):
    """Run *instance*'s ``preprocess`` on *code*.

    Falls back to the original *code* when preprocessing strips everything
    away, so callers never receive an effectively empty snippet.
    """
    cleaned = instance.preprocess(code)
    if cleaned.strip():
        return cleaned
    return code
|
|
|
|
def token_to_inputs(feature):
    """Convert a tokenizer feature mapping into model-ready tensors.

    Each value is turned into a tensor with a leading batch dimension of 1.
    """
    return {key: torch.tensor(value).unsqueeze(0) for key, value in feature.items()}
|
|
|
|
class CloneDetectionPipeline(Pipeline):
    """Pipeline that scores whether two code snippets are clones of each other.

    Input is a pair ``(code1, code2)``; output maps ``False``/``True`` to
    the not-clone/clone probabilities. Both orderings of the pair are fed
    through the model and their logits averaged, making the score symmetric.
    """

    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, **kwargs):
        # No custom pipeline parameters are supported.
        return {}, {}, {}

    def preprocess(self, inputs):
        code1 = inputs[0]
        code2 = inputs[1]

        # An empty snippet can be decided without the model: two empty
        # snippets are trivially identical (prob 1), otherwise prob 0.
        if not code1.strip() or not code2.strip():
            true_prob = float(code1.strip() == code2.strip())
            return {"skip": True, "output": {False: 1 - true_prob, True: true_prob}}

        def _clean(code):
            # Drop unused functions first, then comments/docstrings/imports.
            return preprocessor(
                preprocessor(code, self.fn_preprocessor), self.an_preprocessor
            )

        code1 = _clean(code1)
        code2 = _clean(code2)

        feature1 = self.tokenizer(
            code1, code2, max_length=512, return_token_type_ids=False, truncation=True
        )
        feature2 = self.tokenizer(
            code2, code1, max_length=512, return_token_type_ids=False, truncation=True
        )

        return {
            "inputs1": token_to_inputs(feature1),
            "inputs2": token_to_inputs(feature2),
        }

    def _forward(self, model_inputs):
        # Pass the precomputed answer straight through for empty inputs.
        if model_inputs.get("skip", False):
            return model_inputs

        logits1 = self.model(**model_inputs["inputs1"]).logits[0]
        logits2 = self.model(**model_inputs["inputs2"]).logits[0]

        # Average both pair orderings so the score is order-invariant.
        return {"logits": (logits1 + logits2) / 2}

    def postprocess(self, model_outputs):
        if model_outputs.get("skip", False):
            return model_outputs["output"]

        probs = model_outputs["logits"].softmax(-1).tolist()

        return {False: probs[0], True: probs[1]}
|
|