Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import csv
import os

import cv2
import matplotlib
import numpy as np
import pandas as pd
from matplotlib import pyplot, patches

import tools
def display_prediction(data, image_dir="SROIE2019/test/img", output_path="prediction.png"):
    """Plot true and predicted SROIE labels side by side for one image.

    Draws one rectangle per row of ``data`` on two copies of the image:
    the left copy is coloured by ``true_category``, the right one by
    ``prediction_category``. The prediction overlay is written to
    ``output_path`` and both are shown in a matplotlib figure with a legend.

    Args:
        data: A ``(key, rows)`` pair, e.g. one item of
            ``list(df.groupby("image"))``. The text of ``key`` before its
            extension is used as the image stem (``<stem>.jpg``). ``rows``
            must provide ``bbox`` (whitespace-separated ``"x1 y1 x2 y2"``
            integers), ``true_category`` and ``prediction_category`` columns.
        image_dir: Directory containing the ``.jpg`` images.
        output_path: File the prediction overlay is saved to.

    Raises:
        FileNotFoundError: If the image cannot be read by OpenCV.
        KeyError: If a label is not a key of the colour table below.
        ValueError: If a bbox does not contain exactly four integers.
    """
    # Label -> RGB colour. The image is converted to RGB after loading, so
    # these triples mean the same thing to cv2.rectangle and to matplotlib.
    colors = {
        "S-TOTAL": (255, 0, 0),
        "S-DATE": (0, 255, 0),
        "S-ADDRESS": (0, 0, 255),
        "S-COMPANY": (255, 255, 0),
        "O": (192, 192, 192),
    }
    # Legend entries in display order; colours are taken from the table
    # above so the legend always matches the drawn boxes.
    legend_entries = [
        ("S-COMPANY", "Company"),
        ("S-ADDRESS", "Address"),
        ("S-DATE", "Date"),
        ("S-TOTAL", "Total"),
        ("O", "Other"),
    ]

    image_key, rows = data
    imagename = os.path.splitext(str(image_key))[0] + ".jpg"
    print("Filename:", imagename)
    image_path = os.path.join(image_dir, imagename)

    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads BGR; matplotlib displays RGB.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_prediction = img.copy()

    for bbox, category, prediction_category in zip(
        rows["bbox"], rows["true_category"], rows["prediction_category"]
    ):
        x1, y1, x2, y2 = (int(coordinate) for coordinate in bbox.split())
        # Exact comparison: a substring test ("O" in label) also matched
        # "S-TOTAL" and "S-COMPANY" and drew them with the thin width.
        img_prediction = cv2.rectangle(
            img_prediction, (x1, y1), (x2, y2),
            colors[prediction_category],
            2 if prediction_category == "O" else 4,
        )
        img = cv2.rectangle(
            img, (x1, y1), (x2, y2),
            colors[category],
            2 if category == "O" else 4,
        )

    # cv2.imwrite expects BGR, so convert back before saving.
    cv2.imwrite(output_path, cv2.cvtColor(img_prediction, cv2.COLOR_RGB2BGR))

    # Plot. Size is set per figure instead of mutating global rcParams.
    fig, ax = pyplot.subplots(1, 2, figsize=(15, 18))
    ax[0].set_title("Original", fontsize=30)
    ax[0].imshow(img)
    ax[1].set_title("Prediction", fontsize=30)
    ax[1].imshow(img_prediction)

    # Legend: matplotlib wants colour components in the 0-1 range.
    handles = [
        patches.Patch(color=tuple(c / 255 for c in colors[key]), label=label)
        for key, label in legend_entries
    ]
    fig.legend(handles=handles, prop={'size': 25}, loc='lower center')
    pyplot.show()
if __name__ == "__main__":
    # The three files are joined purely by row position (left_index /
    # right_index merges below), so every file must parse to exactly one
    # row per input line. QUOTE_NONE keeps a token containing '"' from
    # being read as the start of a quoted field, which would merge or split
    # rows and shift every later row out of alignment.
    data = pd.read_csv(
        "dataset/test_image.txt",
        delimiter="\t",
        names=["name", "bbox", "size", "image"],
        quoting=csv.QUOTE_NONE,
    )
    data_category = pd.read_csv(
        "dataset/test.txt",
        delimiter="\t",
        names=["name", "true_category"],
        quoting=csv.QUOTE_NONE,
    ).drop(columns=["name"])
    data_prediction_category = pd.read_csv(
        "SROIE2019/unilm/layoutlm/deprecated/examples/seq_labeling/output/test_predictions.txt",
        delimiter=" ",
        names=["name", "prediction_category"],
        quoting=csv.QUOTE_NONE,
    ).drop(columns=["name"])

    # NOTE(review): these are inner joins on the row index; if the files'
    # row counts differ, extra rows are dropped silently and labels may be
    # paired with the wrong bbox — confirm the files line up.
    data_merge = data.merge(data_category, left_index=True, right_index=True)
    merged = data_merge.merge(data_prediction_category, left_index=True, right_index=True)
    # One (image, rows) pair per image, suitable for display_prediction().
    merged_groups = list(merged.groupby("image"))