# Source code for sphinxcontrib.bibtex.domain

"""
    Classes and methods to maintain any bibtex information that is stored
    outside the doctree.

    .. autoclass:: Citation
        :members:

    .. autoclass:: BibtexDomain
        :members:
"""

import ast
import re
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    cast,
)

import docutils.frontend
import docutils.nodes
import docutils.parsers.rst
import docutils.utils
import pybtex.plugin
import pybtex_docutils
import sphinx.util
from pybtex.richtext import Tag
from pybtex.style import FormattedEntry
from pybtex.style.template import FieldIsMissing
from sphinx.domains import Domain, ObjType
from sphinx.errors import ExtensionError
from sphinx.locale import _

import sphinxcontrib.bibtex.plugin

from .bibfile import BibData, normpath_filename, process_bibdata
from .citation_target import CitationTarget, parse_citation_targets
from .roles import CiteRole
from .style.referencing import BaseReferenceStyle, format_references
from .style.template import SphinxReferenceInfo

if TYPE_CHECKING:
    from pybtex.database import Entry
    from pybtex.style.formatting import BaseStyle
    from sphinx.addnodes import pending_xref
    from sphinx.application import Sphinx
    from sphinx.builders import Builder
    from sphinx.environment import BuildEnvironment

    from .directives import BibliographyKey, BibliographyValue
    from .roles import CitationRef

logger = sphinx.util.logging.getLogger(__name__)


def _raise_invalid_node(node):
    """Report that *node* is not allowed in a filter expression.

    :param node: The offending syntax tree node; it is interpolated into
        the error message.
    :raises ValueError: Always.
    """
    message = "invalid node %s in filter expression" % node
    raise ValueError(message)


class _FilterVisitor(ast.NodeVisitor):

    """Visit the abstract syntax tree of a parsed filter expression.

    Each ``visit_*`` method returns the value of the node it is given.
    Constructs that are not supported, and operations that cannot be
    evaluated for the operands at hand (mismatched operand types, malformed
    regular expressions), are reported as :exc:`ValueError`.
    """

    entry = None
    """The bibliographic entry to which the filter must be applied."""

    cited_docnames = False
    """The documents where the entry is cited (empty if not cited)."""

    # Supported binary comparators: (ast operator class, symbol used in
    # error messages, implementation).
    _COMPARATORS = (
        (ast.Eq, "==", lambda left, right: left == right),
        (ast.NotEq, "!=", lambda left, right: left != right),
        (ast.Lt, "<", lambda left, right: left < right),
        (ast.LtE, "<=", lambda left, right: left <= right),
        (ast.Gt, ">", lambda left, right: left > right),
        (ast.GtE, ">=", lambda left, right: left >= right),
        (ast.In, "in", lambda left, right: left in right),
        (ast.NotIn, "not in", lambda left, right: left not in right),
    )

    def __init__(self, entry, docname, cited_docnames):
        """Store the evaluation context.

        :param entry: Entry whose ``type``, ``key``, ``persons`` and
            ``fields`` attributes are read by :meth:`visit_Name`.
        :param docname: Value of the ``docname`` identifier.
        :param cited_docnames: Value of the ``docnames`` identifier; its
            truthiness is the value of ``cited``.
        """
        self.entry = entry
        self.docname = docname
        self.cited_docnames = cited_docnames

    @staticmethod
    def _apply(symbol, operation, left, right):
        """Return ``operation(left, right)``.

        A :exc:`TypeError` raised by the operation (operands of
        incompatible types) is re-raised as :exc:`ValueError`, which is the
        exception type this visitor uses for every invalid filter.
        """
        try:
            return operation(left, right)
        except TypeError as err:
            raise ValueError(
                "cannot evaluate %r %s %r; %s" % (left, symbol, right, err)
            ) from err

    def visit_Module(self, node):
        """Evaluate the single expression that makes up the filter."""
        if len(node.body) != 1:
            raise ValueError("filter expression cannot contain multiple expressions")
        return self.visit(node.body[0])

    def visit_Expr(self, node):
        """Evaluate an expression statement."""
        return self.visit(node.value)

    def visit_BoolOp(self, node):
        """Evaluate ``and`` / ``or``; operands are visited lazily."""
        outcomes = (self.visit(value) for value in node.values)
        if isinstance(node.op, ast.And):
            return all(outcomes)
        elif isinstance(node.op, ast.Or):
            return any(outcomes)
        else:  # pragma: no cover
            # there are no other boolean operators
            # so this code should never execute
            assert False, "unexpected boolean operator %s" % node.op

    def visit_UnaryOp(self, node):
        """Evaluate ``not``; every other unary operator is rejected."""
        if isinstance(node.op, ast.Not):
            return not self.visit(node.operand)
        else:
            _raise_invalid_node(node)

    def visit_BinOp(self, node):
        """Evaluate ``%`` (regular expression search), ``|`` and ``&``.

        :raises ValueError: If an operand of ``%`` is not a string, if the
            right operand of ``%`` is not a valid regular expression, if the
            operands of ``|`` or ``&`` do not support the operation, or if
            any other binary operator is used.
        """
        left = self.visit(node.left)
        op = node.op
        right = self.visit(node.right)
        if isinstance(op, ast.Mod):
            # modulo operator is used for regular expression matching
            if not isinstance(left, str):
                raise ValueError("expected a string on left side of %s" % node.op)
            if not isinstance(right, str):
                raise ValueError("expected a string on right side of %s" % node.op)
            try:
                return re.search(right, left, re.IGNORECASE)
            except re.error as err:
                # a malformed pattern is a user error in the filter
                raise ValueError(
                    "invalid regular expression %r; %s" % (right, err)
                ) from err
        elif isinstance(op, ast.BitOr):
            return self._apply("|", lambda a, b: a | b, left, right)
        elif isinstance(op, ast.BitAnd):
            return self._apply("&", lambda a, b: a & b, left, right)
        else:
            _raise_invalid_node(node)

    def visit_Compare(self, node):
        """Evaluate a single binary comparison.

        :raises ValueError: For chained comparisons, for unsupported
            comparators, and when the operands cannot be compared.
        """
        # keep it simple: binary comparators only
        if len(node.ops) != 1:
            raise ValueError("syntax for multiple comparators not supported")
        left = self.visit(node.left)
        op = node.ops[0]
        right = self.visit(node.comparators[0])
        for op_type, symbol, operation in self._COMPARATORS:
            if isinstance(op, op_type):
                return self._apply(symbol, operation, left, right)
        # not used currently: ast.Is | ast.IsNot
        _raise_invalid_node(op)

    def visit_Name(self, node):
        """Calculate the value of the given identifier."""
        id_ = node.id
        if id_ == "type":
            return self.entry.type.lower()
        elif id_ == "key":
            return self.entry.key.lower()
        elif id_ == "cited":
            return bool(self.cited_docnames)
        elif id_ == "docname":
            return self.docname
        elif id_ == "docnames":
            return self.cited_docnames
        elif id_ == "author" or id_ == "editor":
            if id_ in self.entry.persons:
                return " and ".join(
                    str(person)  # XXX needs fix in pybtex?
                    for person in self.entry.persons[id_]
                )
            else:
                return ""
        else:
            # any other identifier is looked up as an entry field
            return self.entry.fields.get(id_, "")

    def visit_Set(self, node):
        """Evaluate a set display to a frozenset of its element values."""
        return frozenset(self.visit(elt) for elt in node.elts)

    # NameConstant is Python 3.4 only
    def visit_NameConstant(self, node):
        return node.value  # pragma: no cover

    # Constant is Python 3.6+ only
    # Since 3.8 Num, Str, Bytes, NameConstant and Ellipsis are just Constant
    def visit_Constant(self, node):
        return node.value

    # Not used on 3.8+
    def visit_Str(self, node):
        return node.s  # pragma: no cover

    def generic_visit(self, node):
        """Reject every node type without a dedicated ``visit_*`` method."""
        _raise_invalid_node(node)


def get_docnames(env):
    """Yield every document name of the project.

    Starting at the root document, the "next document" relation reported by
    ``env.collect_relations()`` is followed until it ends; the remaining
    entries of ``env.found_docs`` are then yielded in sorted order.
    """
    relations = env.collect_relations()
    # the configuration value is called ``root_doc`` from Sphinx 4.0 on
    if sphinx.version_info < (4, 0):
        current = env.config.master_doc
    else:
        current = env.config.root_doc
    visited = set()
    while current is not None:
        visited.add(current)
        yield current
        _parent, _previous, current = relations[current]
    yield from sorted(env.found_docs - visited)


class Citation(NamedTuple):
    """Information about a citation."""

    citation_id: str  #: Unique id of this citation.
    bibliography_key: "BibliographyKey"  #: Key of its bibliography directive.
    key: str  #: Key (with prefix).
    entry: "Entry"  #: Entry from pybtex.
    formatted_entry: "FormattedEntry"  #: Formatted entry for bibliography.
    tooltip_entry: Optional["FormattedEntry"]  #: Formatted entry for tooltip.


def env_updated(app: "Sphinx", env: "BuildEnvironment") -> Iterable[str]:
    """Forward to :meth:`BibtexDomain.env_updated` of the ``cite`` domain.

    :param app: The Sphinx application (unused).
    :param env: The build environment whose ``cite`` domain is updated.
    :return: Whatever the domain's ``env_updated`` method returns.
    """
    domain = cast(BibtexDomain, env.get_domain("cite"))
    return domain.env_updated()


def parse_header(header: str, source_path: str):
    """Parse *header* as reStructuredText and return the first node.

    :param header: reStructuredText source of the header.
    :param source_path: Source path passed to the new docutils document.
    :return: The first child of the parsed document (assumes that *header*
        produces at least one node).
    """
    parser = docutils.parsers.rst.Parser()
    # OptionParser is deprecated since docutils 0.19; prefer its documented
    # replacement when the installed docutils provides it.
    get_default_settings = getattr(docutils.frontend, "get_default_settings", None)
    if get_default_settings is not None:
        settings = get_default_settings(docutils.parsers.rst.Parser)
    else:
        # note: types stub for docutils doesn't know about components argument
        settings = docutils.frontend.OptionParser(
            components=(docutils.parsers.rst.Parser,)  # type: ignore
        ).get_default_values()
    document = docutils.utils.new_document(source_path, settings)
    parser.parse(header, document)
    return document[0]


class BibtexDomain(Domain):
    """Sphinx domain for the bibtex extension."""

    name = "cite"
    label = "BibTeX Citations"
    data_version = 4
    initial_data = dict(
        bibdata=BibData(
            encoding="", bibfiles={}, data=pybtex.database.BibliographyData()
        ),
        bibliography_header=docutils.nodes.container(),
        bibliographies={},
        citations=[],
        citation_refs=[],
    )

    backend = pybtex_docutils.Backend()
    reference_style: BaseReferenceStyle

    @property
    def bibdata(self) -> BibData:
        """Information about the bibliography files."""
        return self.data["bibdata"]

    @property
    def bibliography_header(self) -> docutils.nodes.Element:
        """Container holding the parsed ``bibtex_bibliography_header``."""
        return self.data["bibliography_header"]

    @property
    def bibliographies(self) -> Dict["BibliographyKey", "BibliographyValue"]:
        """Map storing information about each bibliography directive."""
        return self.data["bibliographies"]

    @property
    def citations(self) -> List[Citation]:
        """Citation data."""
        return self.data["citations"]

    @property
    def citation_refs(self) -> List["CitationRef"]:
        """Citation reference data."""
        return self.data["citation_refs"]

    def __init__(self, env: "BuildEnvironment"):
        """Set up the referencing style, roles, bib file cache and header.

        :raises ExtensionError: If ``bibtex_bibfiles`` is not configured.
        """
        # set up referencing style
        style = sphinxcontrib.bibtex.plugin.find_plugin(
            "sphinxcontrib.bibtex.style.referencing",
            env.app.config.bibtex_reference_style,
        )
        self.reference_style = style()
        # set up object types and roles for referencing style
        role_names = self.reference_style.role_names()
        self.object_types = dict(
            citation=ObjType(_("citation"), *role_names, searchprio=-1),
        )
        self.roles = dict((name, CiteRole()) for name in role_names)
        # initialize the domain
        super().__init__(env)
        # connect env-updated
        env.app.connect("env-updated", env_updated)
        # check config
        if env.app.config.bibtex_bibfiles is None:
            raise ExtensionError("You must configure the bibtex_bibfiles setting")
        # update bib file information in the cache
        bibfiles = [
            normpath_filename(env, "/" + bibfile)
            for bibfile in env.app.config.bibtex_bibfiles
        ]
        self.data["bibdata"] = process_bibdata(
            self.bibdata, bibfiles, env.app.config.bibtex_encoding
        )
        # parse bibliography header
        # The domain data may have been restored from a pickle (see
        # env_updated), in which case the stored container could already
        # hold a previously parsed header. Build it afresh from the
        # configuration instead of appending to whatever is stored.
        header_container = docutils.nodes.container()
        header = env.app.config.bibtex_bibliography_header
        if header:
            header_container += parse_header(header, "bibliography_header")
        self.data["bibliography_header"] = header_container

    def clear_doc(self, docname: str) -> None:
        """Drop all citations, references and bibliographies of *docname*."""
        self.data["citations"] = [
            citation
            for citation in self.citations
            if citation.bibliography_key.docname != docname
        ]
        self.data["citation_refs"] = [
            ref for ref in self.citation_refs if ref.docname != docname
        ]
        # copy the keys first: entries are deleted while looping
        for bib_key in list(self.bibliographies.keys()):
            if bib_key.docname == docname:
                del self.bibliographies[bib_key]

    def merge_domaindata(self, docnames: List[str], otherdata: Dict) -> None:
        """Copy bibliographies and citation references of *docnames* from
        *otherdata* into this domain's data.
        """
        for bib_key, bib_value in otherdata["bibliographies"].items():
            if bib_key.docname in docnames:
                self.bibliographies[bib_key] = bib_value
        for citation_ref in otherdata["citation_refs"]:
            if citation_ref.docname in docnames:
                self.citation_refs.append(citation_ref)
        # 'citations' domain data calculated in env_updated

    def env_updated(self) -> Iterable[str]:
        """Recompute the citation list and warn about duplicates.

        :return: An empty list (no documents need to be re-read).
        """
        # This function is called when all doctrees are parsed,
        # but before any post transforms are applied. We use it to
        # determine which citations will be added to which bibliography
        # directive, and also to format the labels. We need to format
        # the labels here because they must be known when resolve_xref is
        # called.
        self.citations.clear()  # might have been restored from pickle
        docnames = list(get_docnames(self.env))
        # we keep track of this to quickly check for duplicates
        used_keys: Set[str] = set()
        used_labels: Dict[str, str] = {}
        for bibliography_key, bibliography in self.bibliographies.items():
            for entry, formatted_entry, tooltip_entry in self.get_formatted_entries(
                bibliography_key,
                docnames,
                self.env.app.config.bibtex_tooltips,
                self.env.app.config.bibtex_tooltips_style,
            ):
                key = bibliography.keyprefix + formatted_entry.key
                if bibliography.list_ == "citation" and key in used_keys:
                    logger.warning(
                        'duplicate citation for key "%s"' % key,
                        location=(bibliography_key.docname, bibliography.line),
                        type="bibtex",
                        subtype="duplicate_citation",
                    )
                self.citations.append(
                    Citation(
                        citation_id=bibliography.citation_nodes[key]["ids"][0],
                        bibliography_key=bibliography_key,
                        key=key,
                        entry=entry,
                        formatted_entry=formatted_entry,
                        tooltip_entry=tooltip_entry,
                    )
                )
                # NOTE(review): the label check is assumed to apply to
                # "citation" lists only -- confirm the nesting against the
                # upstream source.
                if bibliography.list_ == "citation":
                    used_keys.add(key)
                    if formatted_entry.label not in used_labels:
                        used_labels[formatted_entry.label] = key
                    elif used_labels[formatted_entry.label] != key:
                        # if used_label[label] == key then already
                        # duplicate key warning
                        logger.warning(
                            'duplicate label "%s" for keys "%s" and "%s"'
                            % (
                                formatted_entry.label,
                                used_labels[formatted_entry.label],
                                key,
                            ),
                            location=(bibliography_key.docname, bibliography.line),
                            type="bibtex",
                            subtype="duplicate_label",
                        )
        return []  # expects list of updated docnames

    def resolve_xref(
        self,
        env: "BuildEnvironment",
        fromdocname: str,
        builder: "Builder",
        typ: str,
        target: str,
        node: "pending_xref",
        contnode: docutils.nodes.Element,
    ) -> docutils.nodes.Element:
        """Replace node by list of citation references (one for each key)."""
        targets = parse_citation_targets(target)
        keys: Dict[str, CitationTarget] = {target2.key: target2 for target2 in targets}
        # only entries of "citation" style bibliographies can be referenced
        citations: Dict[str, Citation] = {
            cit.key: cit
            for cit in self.citations
            if cit.key in keys
            and self.bibliographies[cit.bibliography_key].list_ == "citation"
        }
        for key in keys:
            if key not in citations:
                logger.warning(
                    'could not find bibtex key "%s"' % key,
                    location=node,
                    type="bibtex",
                    subtype="key_not_found",
                )
        plaintext = pybtex.plugin.find_plugin("pybtex.backends", "plaintext")()
        references = [
            (
                citation.entry,
                citation.formatted_entry,
                SphinxReferenceInfo(
                    builder=builder,
                    fromdocname=fromdocname,
                    todocname=citation.bibliography_key.docname,
                    citation_id=citation.citation_id,
                    title=(
                        citation.tooltip_entry.text.render(plaintext).replace(
                            "\\url ", ""
                        )
                        if citation.tooltip_entry
                        else None
                    ),
                    pre_text=keys[citation.key].pre,
                    post_text=keys[citation.key].post,
                ),
            )
            for citation in citations.values()
        ]
        formatted_references = format_references(self.reference_style, typ, references)
        result_node = docutils.nodes.inline(rawsource=target)
        result_node += formatted_references.render(self.backend)
        return result_node

    def resolve_any_xref(
        self,
        env: "BuildEnvironment",
        fromdocname: str,
        builder: "Builder",
        target: str,
        node: "pending_xref",
        contnode: docutils.nodes.Element,
    ) -> List[Tuple[str, docutils.nodes.Element]]:
        """Replace node by list of citation references (one for each key),
        provided that the target has citation keys.
        """
        keys = [key.strip() for key in target.split(",")]
        citations: Set[str] = {
            cit.key
            for cit in self.citations
            if cit.key in keys
            and self.bibliographies[cit.bibliography_key].list_ == "citation"
        }
        if any(key in citations for key in keys):
            result_node = self.resolve_xref(
                env, fromdocname, builder, "p", target, node, contnode
            )
            return [("p", result_node)]
        else:
            return []

    def get_all_cited_keys(self, docnames):
        """Yield all citation keys for given *docnames* in order, then
        ordered by citation order.

        :raises KeyError: If a citation reference belongs to a document that
            is not listed in *docnames*.
        """
        # Position of the first occurrence of every document name, computed
        # once instead of searching *docnames* for every citation reference.
        position: Dict[str, int] = {}
        for index, docname in enumerate(docnames):
            position.setdefault(docname, index)
        # sorted() is stable, so references of one document keep their order
        for citation_ref in sorted(
            self.citation_refs, key=lambda ref: position[ref.docname]
        ):
            for target in citation_ref.targets:
                yield target.key

    def get_entries(self, bibfiles: List[str]) -> Iterable["Entry"]:
        """Return all bibliography entries from the bib files, unsorted (i.e.
        in order of appearance in the bib files).
        """
        for bibfile in bibfiles:
            for key in self.bibdata.bibfiles[bibfile].keys:
                yield self.bibdata.data.entries[key]

    def get_filtered_entries(
        self, bibliography_key: "BibliographyKey"
    ) -> Iterable[Tuple[str, "Entry"]]:
        """Return unsorted bibliography entries filtered by the filter
        expression.
        """
        bibliography = self.bibliographies[bibliography_key]
        # Index the citing documents by citation key once, rather than
        # scanning every citation reference again for every entry.
        citing_docnames: Dict[str, Set[str]] = {}
        for citation_ref in self.citation_refs:
            for target in citation_ref.targets:
                citing_docnames.setdefault(target.key, set()).add(
                    citation_ref.docname
                )
        for entry in self.get_entries(bibliography.bibfiles):
            key = bibliography.keyprefix + entry.key
            # hand a private copy to the filter
            cited_docnames = set(citing_docnames.get(key, ()))
            visitor = _FilterVisitor(
                entry=entry,
                docname=bibliography_key.docname,
                cited_docnames=cited_docnames,
            )
            try:
                success = visitor.visit(bibliography.filter_)
            except ValueError as err:
                logger.warning(
                    "syntax error in :filter: expression; %s" % err,
                    location=(bibliography_key.docname, bibliography.line),
                    type="bibtex",
                    subtype="filter_syntax_error",
                )
                # recover by falling back to the default
                success = bool(cited_docnames)
            # explicitly listed keys are included regardless of the filter
            if success or entry.key in bibliography.keys:
                yield key, entry

    def get_sorted_entries(
        self, bibliography_key: "BibliographyKey", docnames: List[str]
    ) -> Iterable[Tuple[str, "Entry"]]:
        """Return filtered bibliography entries sorted by citation order."""
        entries = dict(self.get_filtered_entries(bibliography_key))
        for key in self.get_all_cited_keys(docnames):
            try:
                # pop, so that a key cited several times is yielded once
                entry = entries.pop(key)
            except KeyError:
                pass
            else:
                yield key, entry
        # then all remaining keys, in order of bibliography file
        for key, entry in entries.items():
            yield key, entry

    def get_formatted_entries(
        self,
        bibliography_key: "BibliographyKey",
        docnames: List[str],
        tooltips: bool,
        tooltips_style: str,
    ) -> Iterable[Tuple["Entry", "FormattedEntry", Optional["FormattedEntry"]]]:
        """Get sorted bibliography entries along with their pybtex labels,
        with additional sorting and formatting applied from the pybtex style.
        """
        bibliography = self.bibliographies[bibliography_key]
        entries = dict(self.get_sorted_entries(bibliography_key, docnames))
        style: BaseStyle = cast(
            "BaseStyle",
            pybtex.plugin.find_plugin("pybtex.style.formatting", bibliography.style)(),
        )
        # style used for tooltips: none if tooltips are disabled, otherwise
        # the dedicated tooltip style, or the bibliography style if unset
        style2: Optional[BaseStyle] = (
            (
                cast(
                    "BaseStyle",
                    pybtex.plugin.find_plugin(
                        "pybtex.style.formatting", tooltips_style
                    )(),
                )
                if tooltips_style
                else style
            )
            if tooltips
            else None
        )
        sorted_entries: Iterable[Entry] = style.sort(entries.values())
        labels = style.format_labels(sorted_entries)
        for label, entry in zip(labels, sorted_entries):
            try:
                yield (
                    entry,
                    style.format_entry(bibliography.labelprefix + label, entry),
                    style2.format_entry(bibliography.labelprefix + label, entry)
                    if style2
                    else None,
                )
            except FieldIsMissing as exc:
                logger.warning(
                    str(exc),
                    location=(bibliography_key.docname, bibliography.line),
                    type="bibtex",
                    subtype="missing_field",
                )
                # emit a placeholder entry that shows the error message
                formatted_error_entry = FormattedEntry(
                    entry.key, Tag("b", str(exc)), bibliography.labelprefix + label
                )
                yield entry, formatted_error_entry, None