# Source code for joeynmt.plotting

# coding: utf-8
"""
Plot attentions
"""
from typing import List, Optional

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.figure import Figure

# Select matplotlib's non-interactive "Agg" backend so figures can be rendered
# and written to files without a display (e.g. on training servers).
# NOTE(review): this call runs after `matplotlib.pyplot` has been imported;
# recent matplotlib versions permit switching backends then, older ones may
# warn or ignore it — confirm against the pinned matplotlib version.
matplotlib.use("Agg")
# Uncomment to register a CJK-capable font file (IPAex Gothic) with matplotlib
# so CJK tokens render in attention plots.
# matplotlib.font_manager.fontManager.addfont("ipaexg.ttf")


def plot_heatmap(
    scores: np.ndarray,
    column_labels: List[str],
    row_labels: List[str],
    output_path: Optional[str] = None,
    dpi: int = 300,
) -> Figure:
    """
    Plotting function that can be used to visualize (self-)attention.
    Plots are saved if `output_path` is specified, in the format given by its
    extension ('pdf' or 'png').

    :param scores: attention scores, indexed as ``scores[row, column]``; only
        the top-left ``len(row_labels) x len(column_labels)`` part is plotted,
        with the color scale fixed to [0, 1]
    :param column_labels: labels for columns (e.g. target tokens)
    :param row_labels: labels for rows (e.g. source tokens)
    :param output_path: path to save to; must end with ".png" or ".pdf"
    :param dpi: set resolution for matplotlib
    :return: pyplot figure (already released from pyplot's figure manager,
        but still usable, e.g. for rendering or logging)
    :raises AssertionError: if `output_path` does not end with ".png"/".pdf"
    """
    if output_path is not None:
        assert output_path.endswith((".png", ".pdf")), \
            "output path must have .png or .pdf extension"

    x_sent_len = len(column_labels)
    y_sent_len = len(row_labels)
    # Keep only the labelled region; any extra rows/columns (e.g. padding
    # positions) are discarded without being inspected.
    scores = scores[:y_sent_len, :x_sent_len]

    # Shrink tick labels as the number of tokens grows. The `1` guards against
    # a ZeroDivisionError when both label lists are empty.
    labelsize = 25 * (10 / max(x_sent_len, y_sent_len, 1))

    fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi)
    ax.imshow(
        scores,
        cmap="viridis",
        aspect="equal",
        origin="upper",
        vmin=0.0,
        vmax=1.0,
    )
    ax.xaxis.tick_top()
    ax.set_xticks(np.arange(scores.shape[1]), minor=False)
    ax.set_yticks(np.arange(scores.shape[0]), minor=False)
    # The label size is applied to this figure's tick labels only, instead of
    # mutating the global rcParams, so other plots in the process are not
    # affected. For CJK tokens, a CJK-capable font family (e.g. "IPAexGothic")
    # has to be configured separately.
    ax.set_xticklabels(
        column_labels, minor=False, rotation="vertical", fontsize=labelsize
    )
    ax.set_yticklabels(row_labels, minor=False, fontsize=labelsize)
    fig.tight_layout()

    if output_path is not None:
        if output_path.endswith(".pdf"):
            # Context manager guarantees the PDF file is closed even if
            # saving fails.
            with PdfPages(output_path) as pp:
                pp.savefig(fig)
        else:
            # Only reachable with an unexpected extension when assertions are
            # disabled (python -O); fall back to PNG in that case.
            if not output_path.endswith(".png"):
                output_path += ".png"
            fig.savefig(output_path)

    # Release this specific figure from pyplot to avoid accumulating open
    # figures across many calls; the Figure object itself is still returned.
    plt.close(fig)
    return fig