Skip to content
Snippets Groups Projects
bornplot2.py 4.36 KiB
"""
Utilities to plot form factors of particles in Born approximation
"""
import math, numpy
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import rc

# Global matplotlib styling for all figures produced by this module:
# Helvetica sans-serif text rendered through LaTeX, the 'inferno' colormap,
# and no smoothing of image pixels.
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
    'text.usetex': True,
    'image.cmap': 'inferno',
    'image.interpolation': 'none',
})

import bornagain as ba
from bornagain import ba_plot as bp


class NamedResult:
    """
    A simulation result bundled with an optional subplot title.

    :param result: Simulation result; the plotting and saving helpers in this
        module call its ``array()`` method
    :param title: Title shown above the subplot; an empty string means no title
    """

    def __init__(self, result, title=""):
        self.title = title
        self.result = result


def save_results(namedResults, name):
    """
    Write the array of each result to a text data file.

    One file per result is written with numpy.savetxt, named
    ``<name>.<i>.int``. Here ``i`` is the zero-based index of the result,
    zero-padded to a common width so that the files sort in order
    (e.g. ``name.00.int`` ... ``name.10.int``). Each filename is printed
    before it is written.

    :param namedResults: Sequence of NamedResult-like objects whose
        ``result.array()`` returns an array accepted by numpy.savetxt
    :param name: Filename prefix (may include a directory path)
    """
    count = len(namedResults)
    if count == 0:
        # Nothing to write; also avoids math.log10(0), which used to raise.
        return
    # Same width as before (number of digits of count), but zero-padded:
    # the former "%Ni" format padded with spaces, giving names like "x. 0.int".
    width = len(str(count))
    for i, item in enumerate(namedResults):
        fname = name + "." + str(i).zfill(width) + ".int"
        print(fname)
        numpy.savetxt(fname, item.result.array())


class MultiPlot:
    """
    Figure holding a grid of equally sized subplots, plus space on the right
    for one common color bar.

    Layout margins are first chosen in units of one subplot's edge length and
    then converted to fractions of the total figure size.
    """

    def __init__(self, n, ncol, **kwargs):
        """
        :param n: Number of subplots to be drawn
        :param ncol: Number of columns of the subplot grid
        :param kwargs: Accepted for interface compatibility; currently unused
        """
        self.n = n
        self.ncol = ncol
        # Smallest number of rows that holds n subplots, i.e. ceil(n/ncol).
        self.nrow = 1 + (self.n - 1) // self.ncol

        # Parameters as fraction of subfig size.
        yskip = 0.2
        bottomskip = yskip
        topskip = yskip/2
        xskip = 0.18
        leftskip = xskip
        # The right margin also hosts the color bar and its label.
        rightskip = 0.28 + ncol*0.03
        xtot = self.ncol*1.0 + (self.ncol - 1)*xskip + leftskip + rightskip
        ytot = self.nrow*1.0 + (self.nrow - 1)*yskip + bottomskip + topskip

        # The same parameters as fractions of the total fig size. These are
        # used for left/right/bottom/top below and by plot_colorbar.
        self.xskip = xskip/xtot
        self.leftskip = leftskip/xtot
        self.rightskip = rightskip/xtot
        self.yskip = yskip/ytot
        self.bottomskip = bottomskip/ytot
        self.topskip = topskip/ytot

        # Set total figure dimensions: ftot inches per subplot edge.
        ftot = 5
        self.fontsize = 18 + 36.0/(ncol + 2)
        # squeeze=False always yields a 2-D array of Axes, so flattening works
        # for every grid shape. The former "n > 1" test broke for n == 1 with
        # ncol > 1, where subplots() returns an array, not a single Axes.
        self.fig, tmp = plt.subplots(self.nrow,
                                     self.ncol,
                                     figsize=(ftot*xtot, ftot*ytot),
                                     squeeze=False)
        self.axes = list(tmp.flat)

        # Adjust whitespace around and between subfigures.
        # wspace/hspace are fractions of the average subplot width/height
        # (subplot units), so they take the unscaled values. left/right/
        # bottom/top are figure fractions.
        self.fig.subplots_adjust(wspace=xskip,
                                 hspace=yskip,
                                 left=self.leftskip,
                                 right=1 - self.rightskip,
                                 bottom=self.bottomskip,
                                 top=1 - self.topskip)

    def plot_colorbar(self, im):
        """
        Add one color bar for ``im`` in the right margin of the figure.

        :param im: Mappable (e.g. the image returned by imshow) whose norm and
            colormap the color bar displays
        """
        # Place a narrow strip just right of the subplot grid, spanning its
        # full height. The position is given in figure fractions.
        cbar_ax = self.fig.add_axes([
            1 - self.rightskip + 0.4*self.xskip, self.bottomskip,
            0.25*self.xskip, 1 - self.bottomskip - self.topskip
        ])
        cb = self.fig.colorbar(im, cax=cbar_ax)
        cb.set_label(r'$\left|F(q)\right|^2/V^{\,2}$',
                     fontsize=self.fontsize)


def make_plot_row(namedResults, name, **kwargs):
    """
    Plot all results side by side in a single row, one column per result.

    This is a thin wrapper around make_plot with ncol = len(namedResults).
    Extra keyword arguments are forwarded unchanged.
    """
    ncol = len(namedResults)
    make_plot(namedResults, name, ncol, **kwargs)


def make_plot(namedResults, name, ncol, **kwargs):
    """
    Make a plot with one detector image for each entry of namedResults,
    plus one common color scale, and save it as ``name + ".pdf"``.

    :param namedResults: List of NamedResult objects (simulation result and
        optional subplot title)
    :param name: Filename for multiplot during save (without ".pdf")
    :param ncol: Number of columns in multiplot
    :param kwargs: Forwarded to MultiPlot
    :raises ValueError: If namedResults is empty
    """
    if len(namedResults) == 0:
        raise ValueError("make_plot: namedResults must not be empty")
    multiPlot = MultiPlot(len(namedResults), ncol, **kwargs)

    # Always the same color scale, to facilitate comparisons between figures.
    norm = mpl.colors.LogNorm(1e-8, 1)
    label_size = multiPlot.fontsize
    tick_size = multiPlot.fontsize*21/24
    # Plot the subfigures.
    for i, item in enumerate(namedResults):
        ax = multiPlot.axes[i]
        axes_limits = bp.get_axes_limits(item.result, ba.Axes.UNDEFINED)
        im = ax.imshow(item.result.array(),
                       norm=norm,
                       extent=axes_limits,
                       aspect=1)
        ax.set_xlabel(r'$\phi_{\rm f} (^{\circ})$', fontsize=label_size)
        # Only the leftmost column carries a y label.
        if i % ncol == 0:
            ax.set_ylabel(r'$\alpha_{\rm f} (^{\circ})$', fontsize=label_size)
        if item.title != "":
            ax.set_title(item.title, fontsize=label_size)
        ax.tick_params(axis='both', which='major', labelsize=tick_size)

    # All images share `norm`, so the last one represents the common scale.
    multiPlot.plot_colorbar(im)

    # Export. The figure is left open, so plt.show() can still display it.
    multiPlot.fig.savefig(name + ".pdf", format="pdf", bbox_inches='tight')