import glob
import os
from urllib.request import urlretrieve
from zipfile import ZipFile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Compute device for model inference: first CUDA GPU if available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def download_and_unzip(url, save_path):
    """Download a ZIP archive from `url` to `save_path` and extract it.

    The archive contents are extracted into the directory that contains
    `save_path`. Failures are reported rather than raised, so a bad
    download does not abort the whole script.

    :param url: String, URL of the ZIP archive to download.
    :param save_path: String, local path where the archive is saved.
    """
    print("Downloading and extracting assets....", end="")
    try:
        # Downloading zip file using urllib package.
        urlretrieve(url, save_path)
        # Extracting zip file using the zipfile package.
        with ZipFile(save_path) as z:
            # Extract ZIP file contents in the same directory.
            z.extractall(os.path.split(save_path)[0])
        print("Done")
    except Exception as e:
        # Best-effort: report the failure (bad URL, corrupt archive, ...)
        # instead of crashing.
        print("\nInvalid file.", e)
def read_image(image_path):
    """Load an image from disk and convert it to RGB.

    :param image_path: String, path to the input image.
    :return: PIL.Image in RGB mode.
    """
    image = Image.open(image_path).convert('RGB')
    return image
def segment_into_lines(image):
    """Segment a page image into horizontal text-line crops.

    :param image: PIL.Image of a page containing text.
    :return: Tuple (line_images, bounding_boxes). `line_images` is a list
             of full-width PIL.Image crops, one per detected text line,
             ordered top to bottom; `bounding_boxes` is the matching list
             of (x1, y1, x2, y2) boxes in the original image coordinates.
    """
    # Convert the PIL image to a NumPy array and grayscale.
    img_array = np.array(image.convert('L'))
    # Invert the image (since text is usually darker than background).
    img_array = 255 - img_array
    # Binarize the image using Otsu's thresholding.
    _, binary = cv2.threshold(img_array, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # Full-width horizontal kernel so dilation merges each text line
    # into a single connected blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (binary.shape[1], 1))
    dilated = cv2.dilate(binary, kernel, iterations=5)
    # Find the outer contour of every merged line blob.
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Sort contours from top to bottom by their bounding-box y coordinate.
    contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[1])
    line_images = []
    bounding_boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        # Crop the full width of the original image for this line.
        line_img = image.crop((0, y, image.width, y + h))
        line_images.append(line_img)
        bounding_boxes.append((0, y, image.width, y + h))
    return line_images, bounding_boxes
def ocr_multiline(image, processor, model):
    """Run OCR line by line over a multi-line page image.

    :param image: PIL.Image of the page.
    :param processor: TrOCRProcessor used for pre- and post-processing.
    :param model: VisionEncoderDecoderModel that generates token ids.
    :return: Tuple (recognized_text, bounding_boxes) — the newline-joined
             text of all detected lines, plus the per-line bounding boxes
             from `segment_into_lines`.
    """
    # Segment the image into lines and get bounding boxes.
    line_images, bounding_boxes = segment_into_lines(image)
    recognized_text = ""
    for line_image in line_images:
        # Preprocess the line crop and move it to the compute device.
        pixel_values = processor(line_image, return_tensors='pt').pixel_values.to(device)
        generated_ids = model.generate(pixel_values)
        line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        recognized_text += line_text + "\n"
    return recognized_text, bounding_boxes
def eval_new_data(data_path=None, num_samples=4, model=None):
    """Run OCR on sample images and visualize the detected lines.

    :param data_path: Glob pattern selecting the images to evaluate.
    :param num_samples: Maximum number of images to process.
    :param model: VisionEncoderDecoderModel used for generation
                  (the module-level `processor` is used for decoding).
    """
    image_paths = glob.glob(data_path)
    for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        # Stop after the requested number of samples.
        if i == num_samples:
            break
        image = read_image(image_path)
        text, bounding_boxes = ocr_multiline(image, processor, model)
        plt.figure(figsize=(10, 6))
        plt.imshow(image)
        # Draw a red rectangle around every detected text line.
        for bbox in bounding_boxes:
            x1, y1, x2, y2 = bbox
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='red', linewidth=2)
            plt.gca().add_patch(rect)
        plt.title(text)
        plt.axis('off')
        plt.show()
# Location of the tutorial's sample-image archive and its local save path.
URL = r"https://www.dropbox.com/scl/fi/jz74me0vc118akmv5nuzy/images.zip?rlkey=54flzvhh9xxh45czb1c8n3fp3&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "images.zip")
# Download if assets ZIP does not exist.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)
# Load the pretrained TrOCR processor and model and move the model to
# the selected compute device.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
# Run inference and visualization on the downloaded handwritten samples.
eval_new_data(
    data_path=os.path.join('images', 'handwritten', '*'),
    num_samples=4,
    model=model,
)