| | import argparse |
| | import base64 |
| | from io import BytesIO |
| |
|
| | from PIL import Image |
| |
|
| | from handler import EndpointHandler, decode_base64_image |
| |
|
| |
|
def local_predict(prompts, encode_image, output_path="local_output.png"):
    """Run the endpoint handler in-process and save the image it returns.

    Args:
        prompts: Value sent to the handler under the ``"inputs"`` key.
        encode_image: Value sent to the handler under the ``"image"`` key.
            When falsy (e.g. ``""``) the key is left out of the request.
        output_path: Destination passed to ``image.save``. Defaults to
            ``"local_output.png"``.
    """
    handler = EndpointHandler()

    # Build the request once; the init image is only attached when provided.
    payload = {"inputs": prompts}
    if encode_image:
        payload["image"] = encode_image
    response = handler(payload)

    # NOTE(review): presumably the handler returns the result base64-encoded
    # under "image" -- confirm against handler.py.
    image = decode_base64_image(response["image"])
    image.save(output_path)
| |
|
| |
|
# Command-line interface for the local test; both options default to "".
opt = argparse.ArgumentParser("Diffuser local test")
opt.add_argument(
    "-prompts",
    "--prompts",
    default="",
    type=str,
    help="Diffuser prompts",
)
opt.add_argument(
    "-image",
    "--image",
    default="",
    type=str,
    help="Init image",
)

if __name__ == "__main__":
    cli_args = opt.parse_args()

    # Read the optional init image and base64-encode it into a text string;
    # an empty string means "no init image".
    init_image_b64 = ""
    if cli_args.image:
        with open(cli_args.image, "rb") as image_fh:
            raw_bytes = image_fh.read()
        init_image_b64 = base64.b64encode(raw_bytes).decode()

    local_predict(cli_args.prompts, init_image_b64)
| |
|