"""
Classes and methods to maintain any bibtex information that is stored
outside the doctree.
.. autoclass:: BibliographyKey
:members:
.. autoclass:: BibliographyValue
:members:
.. autoclass:: Citation
:members:
.. autoclass:: CitationRef
:members:
.. autoclass:: BibtexDomain
:members:
"""
import ast
from typing import List, Dict, NamedTuple, cast, Optional, Iterable, Tuple, Set
import docutils.frontend
import docutils.nodes
import docutils.parsers.rst
import docutils.utils
import sphinx.util
import re
from pybtex.database import Entry
from pybtex.plugin import find_plugin
import pybtex.style.formatting
from pybtex.style import FormattedEntry
from sphinx.addnodes import pending_xref
from sphinx.builders import Builder
from sphinx.domains import Domain
from sphinx.environment import BuildEnvironment
from sphinx.errors import ExtensionError
from sphinx.util.nodes import make_refnode
from .bibfile import BibFile, normpath_filename, process_bibfile
logger = sphinx.util.logging.getLogger(__name__)
def _raise_invalid_node(node):
"""Helper method to raise an exception when an invalid node is
visited.
"""
raise ValueError("invalid node %s in filter expression" % node)
class _FilterVisitor(ast.NodeVisitor):
"""Visit the abstract syntax tree of a parsed filter expression."""
entry = None
"""The bibliographic entry to which the filter must be applied."""
cited_docnames = False
"""The documents where the entry is cited (empty if not cited)."""
def __init__(self, entry, docname, cited_docnames):
self.entry = entry
self.docname = docname
self.cited_docnames = cited_docnames
def visit_Module(self, node):
if len(node.body) != 1:
raise ValueError(
"filter expression cannot contain multiple expressions")
return self.visit(node.body[0])
def visit_Expr(self, node):
return self.visit(node.value)
def visit_BoolOp(self, node):
outcomes = (self.visit(value) for value in node.values)
if isinstance(node.op, ast.And):
return all(outcomes)
elif isinstance(node.op, ast.Or):
return any(outcomes)
else: # pragma: no cover
# there are no other boolean operators
# so this code should never execute
assert False, "unexpected boolean operator %s" % node.op
def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Not):
return not self.visit(node.operand)
else:
_raise_invalid_node(node)
def visit_BinOp(self, node):
left = self.visit(node.left)
op = node.op
right = self.visit(node.right)
if isinstance(op, ast.Mod):
# modulo operator is used for regular expression matching
if not isinstance(left, str):
raise ValueError(
"expected a string on left side of %s" % node.op)
if not isinstance(right, str):
raise ValueError(
"expected a string on right side of %s" % node.op)
return re.search(right, left, re.IGNORECASE)
elif isinstance(op, ast.BitOr):
return left | right
elif isinstance(op, ast.BitAnd):
return left & right
else:
_raise_invalid_node(node)
def visit_Compare(self, node):
# keep it simple: binary comparators only
if len(node.ops) != 1:
raise ValueError("syntax for multiple comparators not supported")
left = self.visit(node.left)
op = node.ops[0]
right = self.visit(node.comparators[0])
if isinstance(op, ast.Eq):
return left == right
elif isinstance(op, ast.NotEq):
return left != right
elif isinstance(op, ast.Lt):
return left < right
elif isinstance(op, ast.LtE):
return left <= right
elif isinstance(op, ast.Gt):
return left > right
elif isinstance(op, ast.GtE):
return left >= right
elif isinstance(op, ast.In):
return left in right
elif isinstance(op, ast.NotIn):
return left not in right
else:
# not used currently: ast.Is | ast.IsNot
_raise_invalid_node(op)
def visit_Name(self, node):
"""Calculate the value of the given identifier."""
id_ = node.id
if id_ == 'type':
return self.entry.type.lower()
elif id_ == 'key':
return self.entry.key.lower()
elif id_ == 'cited':
return bool(self.cited_docnames)
elif id_ == 'docname':
return self.docname
elif id_ == 'docnames':
return self.cited_docnames
elif id_ == 'author' or id_ == 'editor':
if id_ in self.entry.persons:
return u' and '.join(
str(person) # XXX needs fix in pybtex?
for person in self.entry.persons[id_])
else:
return u''
else:
return self.entry.fields.get(id_, "")
def visit_Set(self, node):
return frozenset(self.visit(elt) for elt in node.elts)
# NameConstant is Python 3.4 only
def visit_NameConstant(self, node):
return node.value # pragma: no cover
# Constant is Python 3.6+ only
# Since 3.8 Num, Str, Bytes, NameConstant and Ellipsis are just Constant
def visit_Constant(self, node):
return node.value
# Not used on 3.8+
def visit_Str(self, node):
return node.s # pragma: no cover
def generic_visit(self, node):
_raise_invalid_node(node)
def get_docnames(env):
"""Ged document names in order."""
rel = env.collect_relations()
docname = env.config.master_doc
while docname is not None:
yield docname
parent, prevdoc, nextdoc = rel[docname]
docname = nextdoc
[docs]class BibliographyKey(NamedTuple):
docname: str
id_: str
[docs]class BibliographyValue(NamedTuple):
"""Contains information about a bibliography directive."""
line: int #: Line number of the directive in the document.
bibfiles: List[str] #: List of bib files for this directive.
style: str #: The pybtex style.
list_: str #: The list type.
enumtype: str #: The sequence type (for enumerated lists).
start: int #: The start of the sequence (for enumerated lists).
labelprefix: str #: String prefix for pybtex generated labels.
keyprefix: str #: String prefix for citation keys.
filter_: ast.AST #: Parsed filter expression.
citation_nodes: Dict[str, docutils.nodes.citation] #: key -> citation node
[docs]class Citation(NamedTuple):
"""Information about a citation."""
citation_id: Optional[str] #: Unique id of this citation.
bibliography_key: BibliographyKey #: Key of its bibliography directive.
key: str #: Unique citation id used for referencing.
label: str #: Label (with brackets and label prefix).
formatted_entry: FormattedEntry #: Entry as formatted by pybtex.
[docs]class CitationRef(NamedTuple):
"""Information about a citation reference."""
citation_ref_id: str #: Unique id of this citation reference.
docname: str #: Document name.
line: int #: Line number.
keys: List[str] #: Citation keys (including key prefix).
[docs]class BibtexDomain(Domain):
"""Global bibtex extension information cache."""
name = 'cite'
label = 'BibTeX Citations'
data_version = 2
@property
def bibfiles(self) -> Dict[str, BibFile]:
"""Map each bib filename to some information about the file (including
the parsed data).
"""
return self.data.setdefault('bibfiles', {}) # filename -> cache
@property
def bibliographies(self) -> Dict[BibliographyKey, BibliographyValue]:
"""Map storing information about each bibliography directive."""
return self.data.setdefault('bibliographies', {}) # id -> cache
@property
def citations(self) -> List[Citation]:
"""Citation data."""
return self.data.setdefault('citations', [])
@property
def citation_refs(self) -> List[CitationRef]:
"""Citation reference data."""
return self.data.setdefault('citation_refs', [])
def __init__(self, env: BuildEnvironment):
super().__init__(env)
# check config
if env.app.config.bibtex_bibfiles is None:
raise ExtensionError(
"You must configure the bibtex_bibfiles setting")
# update bib file information in the cache
for bibfile in env.app.config.bibtex_bibfiles:
process_bibfile(
self.bibfiles, normpath_filename(env, "/" + bibfile),
env.app.config.bibtex_encoding)
# parse bibliography headers
for directive in ("bibliography", "footbibliography"):
conf_name = "bibtex_{0}_header".format(directive)
if not hasattr(env, conf_name):
parser = docutils.parsers.rst.Parser()
settings = docutils.frontend.OptionParser(
components=(docutils.parsers.rst.Parser,)
).get_default_values()
document = docutils.utils.new_document(
"{0}_header".format(directive), settings)
parser.parse(getattr(env.app.config, conf_name), document)
setattr(env, conf_name,
document[0] if len(document) > 0 else None)
[docs] def clear_doc(self, docname: str) -> None:
self.data['citations'] = [
citation for citation in self.citations
if citation.bibliography_key.docname != docname]
self.data['citation_refs'] = [
ref for ref in self.citation_refs
if ref.docname != docname]
for bib_key, bib_value in list(self.bibliographies.items()):
if bib_key.docname == docname:
del self.bibliographies[bib_key]
[docs] def merge_domaindata(self, docnames: List[str], otherdata: Dict) -> None:
for bib_key, bib_value in otherdata['bibliographies'].items():
if bib_key.docname in docnames:
self.bibliographies[bib_key] = bib_value
for citation_ref in otherdata['citation_refs']:
if citation_ref.docname in docnames:
self.citation_refs.append(citation_ref)
# citations created during check_consistency so never pickled
assert not self.citations
assert not otherdata['citations']
[docs] def check_consistency(self) -> None:
# This function is called when all doctrees are parsed,
# but before any post transforms are applied. We use it to
# determine which citations will be added to which bibliography
# directive, and also to format the labels. We need to format
# the labels and construct the citation ids here because they must be
# known when resolve_xref is called.
docnames = list(get_docnames(self.env))
# we keep track of this to quickly check for duplicates
used_keys = set()
used_labels: Dict[str, Set[str]] = {}
used_ids = set()
for bibliography_key, bibliography in self.bibliographies.items():
for formatted_entry in self.get_formatted_entries(
bibliography_key, docnames):
key = bibliography.keyprefix + formatted_entry.key
label = bibliography.labelprefix + formatted_entry.label
if bibliography.list_ != 'citation':
# no warning in this case, just don't generate link
citation_id = None
elif key in used_keys:
logger.warning(
'duplicate citation for key %s' % key,
location=(bibliography_key.docname, bibliography.line))
# no id for this one
citation_id = None
else:
citation_id = bibliography.citation_nodes[key]['ids'][0]
self.citations.append(Citation(
citation_id=citation_id,
bibliography_key=bibliography_key,
key=key,
label=label,
formatted_entry=formatted_entry,
))
used_keys.add(key)
used_labels.setdefault(label, set()).add(key)
used_ids.add(citation_id)
for label, keys in used_labels.items():
if len(keys) > 1:
logger.warning(
'duplicate label %s for keys %s' % (
label, ','.join(sorted(keys))))
[docs] def resolve_xref(self, env: BuildEnvironment, fromdocname: str,
builder: Builder, typ: str, target: str,
node: pending_xref, contnode: docutils.nodes.Element
) -> docutils.nodes.Element:
"""Replace node by list of citation references (one for each key)."""
keys = [key.strip() for key in target.split(',')]
if builder.name != 'latex':
node = docutils.nodes.inline(rawsource=target, text='[')
else:
node = docutils.nodes.inline(rawsource=target, text='')
# map citation keys that can be resolved to their citation data
citations = {
cit.key: cit for cit in self.citations
if cit.key in keys
and self.bibliographies[cit.bibliography_key].list_ == 'citation'}
for i, key in enumerate(keys):
try:
citation = citations[key]
except KeyError:
# TODO can handle missing reference warning using the domain
logger.warning('could not find bibtex key %s' % key)
node += docutils.nodes.inline('', key)
continue
refcontnode = docutils.nodes.inline('', citation.label)
if builder.name == 'latex':
# latex builder needs a citation_reference
refnode = docutils.nodes.citation_reference(
'', refcontnode,
docname=env.docname,
refname=citation.citation_id)
else:
# other builders can use general reference node
refnode = make_refnode(
builder, env.docname,
citation.bibliography_key.docname,
citation.citation_id, refcontnode)
node += refnode
if i != len(keys) - 1 and builder.name != 'latex':
node += docutils.nodes.Text(',')
if builder.name != 'latex':
node += docutils.nodes.Text(']')
return node
[docs] def get_all_cited_keys(self, docnames):
"""Yield all citation keys for given *docnames* in order, then
ordered by citation order.
"""
for citation_ref in sorted(
self.citation_refs, key=lambda c: docnames.index(c.docname)):
for key in citation_ref.keys:
yield key
[docs] def get_entries(
self, bibfiles: List[str]) -> Iterable[Entry]:
"""Return all bibliography entries from the bib files, unsorted (i.e.
in order of appearance in the bib files.
"""
for bibfile in bibfiles:
for entry in self.bibfiles[bibfile].data.entries.values():
yield entry
[docs] def get_filtered_entries(
self, bibliography_key: BibliographyKey
) -> Iterable[Tuple[str, Entry]]:
"""Return unsorted bibliography entries filtered by the filter
expression.
"""
bibliography = self.bibliographies[bibliography_key]
for entry in self.get_entries(bibliography.bibfiles):
key = bibliography.keyprefix + entry.key
cited_docnames = {
citation_ref.docname
for citation_ref in self.citation_refs
if key in citation_ref.keys
}
visitor = _FilterVisitor(
entry=entry,
docname=bibliography_key.docname,
cited_docnames=cited_docnames)
try:
success = visitor.visit(bibliography.filter_)
except ValueError as err:
logger.warning(
"syntax error in :filter: expression; %s" % err,
location=(bibliography_key.docname, bibliography.line))
# recover by falling back to the default
success = bool(cited_docnames)
if success:
yield key, entry
[docs] def get_sorted_entries(
self, bibliography_key: BibliographyKey, docnames: List[str]
) -> Iterable[Tuple[str, Entry]]:
"""Return filtered bibliography entries sorted by citation order."""
entries = dict(
self.get_filtered_entries(bibliography_key))
for key in self.get_all_cited_keys(docnames):
try:
entry = entries.pop(key)
except KeyError:
pass
else:
yield key, entry
# then all remaining keys, in order of bibliography file
for key, entry in entries.items():
yield key, entry
[docs] def get_formatted_entries(
self, bibliography_key: BibliographyKey, docnames: List[str]
) -> Iterable[FormattedEntry]:
"""Get sorted bibliography entries along with their pybtex labels,
with additional sorting and formatting applied from the pybtex style.
"""
bibliography = self.bibliographies[bibliography_key]
entries = dict(
self.get_sorted_entries(bibliography_key, docnames))
style = cast(
pybtex.style.formatting.BaseStyle,
find_plugin('pybtex.style.formatting', bibliography.style)())
return style.format_entries(entries.values())