# predict.py — draw ground-truth vs. predicted SROIE boxes for a test receipt.
# Author: ivan.rigo
import cv2
from matplotlib import pyplot, patches
import matplotlib
import tools
import numpy as np
import pandas as pd


def display_prediction(data, image_dir="SROIE2019/test/img", output_path="prediction.png"):
    """Show ground-truth and predicted boxes for one receipt side by side.

    Draws one rectangle per token on two copies of the receipt image: the
    left ("Original") copy is coloured by ``true_category``, the right
    ("Prediction") copy by ``prediction_category``. The prediction overlay is
    also written to ``output_path``, then the figure is displayed with
    ``pyplot.show()``.

    Args:
        data: A ``(image_name, rows)`` pair, e.g. one item of
            ``DataFrame.groupby("image")``. The part of ``image_name`` before
            its first ``"."`` plus ``".jpg"`` is the file loaded from
            ``image_dir``. ``rows`` must provide the columns ``bbox`` (a string
            of four space-separated integers ``"x1 y1 x2 y2"``),
            ``true_category`` and ``prediction_category``.
        image_dir: Directory holding the receipt images.
        output_path: File the prediction overlay is written to.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
        KeyError: If a category is not one of the labels in ``colors``.
    """
    import os

    # RGB colour per label. Boxes are drawn on RGB images, so these are also
    # exactly the colours matplotlib shows and the legend uses.
    colors = {
        "S-TOTAL": (255, 0, 0),
        "S-DATE": (0, 255, 0),
        "S-ADDRESS": (0, 0, 255),
        "S-COMPANY": (255, 255, 0),
        "O": (192, 192, 192),
    }
    # Legend text per label, in the order the entries are listed.
    legend_names = {
        "S-COMPANY": "Company",
        "S-ADDRESS": "Address",
        "S-DATE": "Date",
        "S-TOTAL": "Total",
        "O": "Other",
    }

    image_name, rows = data
    imagename = image_name.split(".")[0] + ".jpg"
    print("Filename:", imagename)
    image_path = os.path.join(image_dir, imagename)

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread does not raise on a missing/unreadable file; it returns None.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads BGR while matplotlib expects RGB: convert once before drawing.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_prediction = img.copy()

    for bbox, category, prediction_category in zip(
        rows["bbox"], rows["true_category"], rows["prediction_category"]
    ):
        x1, y1, x2, y2 = (int(coordinate) for coordinate in bbox.split())
        # Compare exactly: a substring test ("O" in label) would also match
        # the "O" inside S-TOTAL and S-COMPANY and draw those boxes thin.
        img_prediction = cv2.rectangle(
            img_prediction,
            (x1, y1),
            (x2, y2),
            colors[prediction_category],
            2 if prediction_category == "O" else 4,
        )
        img = cv2.rectangle(
            img,
            (x1, y1),
            (x2, y2),
            colors[category],
            2 if category == "O" else 4,
        )

    # cv2.imwrite expects BGR; convert back so the saved colours match the legend.
    cv2.imwrite(output_path, cv2.cvtColor(img_prediction, cv2.COLOR_RGB2BGR))

    # Plot (figure size set per-figure instead of mutating global rcParams).
    fig, ax = pyplot.subplots(1, 2, figsize=(15, 18))
    ax[0].set_title("Original", fontsize=30)
    ax[0].imshow(img)
    ax[1].set_title("Prediction", fontsize=30)
    ax[1].imshow(img_prediction)

    # Legend derived from the same colour table used for the boxes
    # (matplotlib colours are 0-1 floats, OpenCV's are 0-255 ints).
    handles = [
        patches.Patch(color=tuple(channel / 255 for channel in colors[label]), label=name)
        for label, name in legend_names.items()
    ]

    fig.legend(handles=handles, prop={"size": 25}, loc="lower center")
    pyplot.show()



if __name__ == "__main__":
    import csv

    # QUOTE_NONE: receipt tokens may contain '"'. With default CSV quoting a
    # field starting with a quote swallows following delimiters/lines, which
    # shifts rows; the three files below are aligned purely by row position.
    data = pd.read_csv(
        "dataset/test_image.txt",
        delimiter="\t",
        names=["name", "bbox", "size", "image"],
        quoting=csv.QUOTE_NONE,
    )
    data_category = pd.read_csv(
        "dataset/test.txt",
        delimiter="\t",
        names=["name", "true_category"],
        quoting=csv.QUOTE_NONE,
    ).drop(columns=["name"])
    data_prediction_category = pd.read_csv(
        "SROIE2019/unilm/layoutlm/deprecated/examples/seq_labeling/output/test_predictions.txt",
        delimiter=" ",
        names=["name", "prediction_category"],
        quoting=csv.QUOTE_NONE,
    ).drop(columns=["name"])

    # The index-on-index inner join silently drops surplus rows; differing
    # lengths mean the token/label/prediction rows are not aligned.
    row_counts = (len(data), len(data_category), len(data_prediction_category))
    if len(set(row_counts)) != 1:
        print(
            "Warning: row counts differ (tokens, labels, predictions) = "
            f"{row_counts}; rows are joined by position and may be misaligned."
        )

    data_merge = data.merge(data_category, left_index=True, right_index=True)
    merged = data_merge.merge(data_prediction_category, left_index=True, right_index=True)

    # Only the first image group is displayed; no need to materialise them all.
    display_prediction(next(iter(merged.groupby("image"))))