# Source code for kgx.transformers.pandas_transformer
import logging
import re
import tarfile
from tempfile import TemporaryFile
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

from kgx.transformers.transformer import Transformer
from kgx.utils import make_path
from kgx.utils.kgx_utils import generate_edge_key
LIST_DELIMITER = '|'
_column_types = {
'publications': list,
'qualifiers': list,
'category': list,
'synonym': list,
'provided_by': list,
'same_as': list,
'negated': bool,
}
_extension_types = {
'csv': ',',
'tsv': '\t',
'txt': '|'
}
_archive_mode = {
'tar': 'r',
'tar.gz': 'r:gz',
'tar.bz2': 'r:bz2'
}
_archive_format = {
'w': 'tar',
'w:gz': 'tar.gz',
'w:bz2': 'tar.bz2'
}
[docs]class PandasTransformer(Transformer):
"""
Transformer that parses a pandas.DataFrame, and loads nodes and edges into a networkx.MultiDiGraph
"""
# TODO: Support parsing and export of neo4j-import tool compatible CSVs with appropriate headers
[docs] def parse(self, filename: str, input_format: str = 'csv', provided_by: str = None, **kwargs) -> None:
"""
Parse a CSV/TSV (or plain text) file.
The file can represent either nodes (nodes.csv) or edges (edges.csv) or both (data.tar),
where the tar archive contains nodes.csv and edges.csv
The file can also be data.tar.gz or data.tar.bz2
Parameters
----------
filename: str
File to read from
input_format: str
The input file format (``csv``, by default)
provided_by: str
Define the source providing the input file
kwargs: Dict
Any additional arguments
"""
if 'delimiter' not in kwargs:
# infer delimiter from file format
kwargs['delimiter'] = _extension_types[input_format]
if filename.endswith('.tar'):
mode = _archive_mode['tar']
elif filename.endswith('.tar.gz'):
mode = _archive_mode['tar.gz']
elif filename.endswith('.tar.bz2'):
mode = _archive_mode['tar.bz2']
else:
# file is not an archive
mode = None
if provided_by:
self.graph_metadata['provided_by'] = [provided_by]
if mode:
with tarfile.open(filename, mode=mode) as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
df = pd.read_csv(f, **kwargs) # type: pd.DataFrame
if re.search('nodes.{}'.format(input_format), member.name):
self.load_nodes(df)
elif re.search('edges.{}'.format(input_format), member.name):
self.load_edges(df)
else:
raise Exception('Tar archive contains an unrecognized file: {}'.format(member.name))
else:
df = pd.read_csv(filename, dtype=str, **kwargs) # type: pd.DataFrame
self.load(df)
[docs] def load(self, df: pd.DataFrame) -> None:
"""
Load a panda.DataFrame, containing either nodes or edges, into a networkx.MultiDiGraph
Parameters
----------
df : pandas.DataFrame
Dataframe containing records that represent nodes or edges
"""
if 'subject' in df:
self.load_edges(df)
else:
self.load_nodes(df)
[docs] def load_nodes(self, df: pd.DataFrame) -> None:
"""
Load nodes from pandas.DataFrame into a networkx.MultiDiGraph
Parameters
----------
df : pandas.DataFrame
Dataframe containing records that represent nodes
"""
for obj in df.to_dict('record'):
self.load_node(obj)
[docs] def load_node(self, node: Dict) -> None:
"""
Load a node into a networkx.MultiDiGraph
Parameters
----------
node : dict
A node
"""
node = Transformer.validate_node(node)
kwargs = PandasTransformer._build_kwargs(node.copy())
if 'id' in kwargs:
n = kwargs['id']
self.graph.add_node(n, **kwargs)
else:
logging.info("Ignoring node with no 'id': {}".format(node))
[docs] def load_edges(self, df: pd.DataFrame) -> None:
"""
Load edges from pandas.DataFrame into a networkx.MultiDiGraph
Parameters
----------
df : pandas.DataFrame
Dataframe containing records that represent edges
"""
for obj in df.to_dict('record'):
self.load_edge(obj)
[docs] def load_edge(self, edge: Dict) -> None:
"""
Load an edge into a networkx.MultiDiGraph
Parameters
----------
edge : dict
An edge
"""
edge = Transformer.validate_edge(edge)
kwargs = PandasTransformer._build_kwargs(edge.copy())
if 'subject' in kwargs and 'object' in kwargs:
s = kwargs['subject']
o = kwargs['object']
key = generate_edge_key(s, kwargs['edge_label'], o)
self.graph.add_edge(s, o, key, **kwargs)
else:
logging.info("Ignoring edge with either a missing 'subject' or 'object': {}".format(kwargs))
[docs] def export_nodes(self) -> pd.DataFrame:
"""
Export nodes from networkx.MultiDiGraph as a pandas.DataFrame
Returns
-------
pandas.DataFrame
A Dataframe where each record corresponds to a node from the networkx.MultiDiGraph
"""
rows = []
for n, data in self.graph.nodes(data=True):
data = self.validate_node(data)
row = PandasTransformer._build_export_row(data.copy())
row['id'] = n
rows.append(row)
df = pd.DataFrame.from_records(rows)
return df
[docs] def export_edges(self) -> pd.DataFrame:
"""
Export edges from networkx.MultiDiGraph as a pandas.DataFrame
Returns
-------
pandas.DataFrame
A Dataframe where each record corresponds to an edge from the networkx.MultiDiGraph
"""
rows = []
for s, o, data in self.graph.edges(data=True):
data = self.validate_edge(data)
row = PandasTransformer._build_export_row(data.copy())
row['subject'] = s
row['object'] = o
rows.append(row)
df = pd.DataFrame.from_records(rows)
cols = df.columns.tolist()
cols = PandasTransformer._order_cols(cols)
df = df[cols]
return df
[docs] def save(self, filename: str, extension: str = 'csv', mode: str = 'w', **kwargs) -> str:
"""
Writes two files representing the node set and edge set of a networkx.MultiDiGraph,
and add them to a `.tar` archive.
Parameters
----------
filename: str
Name of tar archive file to create
extension: str
The output file format (``csv``, by default)
mode: str
Form of compression to use (``w``, by default, signifies no compression)
kwargs: dict
Any additional arguments
"""
if extension not in _extension_types:
raise Exception('Unsupported extension: ' + extension)
archive_name = "{}.{}".format(filename, _archive_format[mode])
delimiter = _extension_types[extension]
nodes_content = self.export_nodes().to_csv(sep=delimiter, index=False, escapechar="\\", doublequote=False)
edges_content = self.export_edges().to_csv(sep=delimiter, index=False, escapechar="\\", doublequote=False)
nodes_file_name = "{}_nodes.{}".format(filename, extension)
edges_file_name = "{}_edges.{}".format(filename, extension)
make_path(archive_name)
with tarfile.open(name=archive_name, mode=mode) as tar:
PandasTransformer._add_to_tar(tar, nodes_file_name, nodes_content)
PandasTransformer._add_to_tar(tar, edges_file_name, edges_content)
return filename
@staticmethod
def _build_kwargs(data: Dict) -> Dict:
"""
Sanitize key-value pairs in dictionary.
Parameters
----------
data: dict
A dictionary containing key-value pairs
Returns
-------
dict
A dictionary containing processed key-value pairs
"""
# remove numpy.nan
data = {k : v for k, v in data.items() if v is not np.nan}
for key, value in data.items():
# process value as a list if key is a multi-valued property
if key in _column_types:
if _column_types[key] == list:
if isinstance(value, (list, set, tuple)):
data[key] = list(value)
elif isinstance(value, str):
data[key] = value.split(LIST_DELIMITER)
else:
data[key] = [str(value)]
elif _column_types[key] == bool:
try:
data[key] = bool(value)
except:
data[key] = False
else:
data[key] = str(value)
return data
@staticmethod
def _build_export_row(data: Dict) -> Dict:
"""
Casts all values to primitive types like str or bool according to the
specified type in ``_column_types``. Lists become pipe delimited strings.
Parameters
----------
data: dict
A dictionary containing key-value pairs
Returns
-------
dict
A dictionary containing processed key-value pairs
"""
data = {k: v for k, v in data.items() if v is not np.nan}
for key, value in data.items():
if key in _column_types:
if _column_types[key] == list:
if isinstance(value, (list, set, tuple)):
data[key] = LIST_DELIMITER.join(value)
else:
data[key] = str(value)
elif _column_types[key] == bool:
try:
data[key] = bool(value)
except:
data[key] = False
else:
# some OWL files provide values that span multiple lines, which
# is parsed as-is by Rdflib. Escaping all new line characters.
value = value.replace('\n', '\\n')
data[key] = str(value)
else:
if type(data[key]) == list:
data[key] = LIST_DELIMITER.join(value)
elif type(data[key]) == bool:
try:
data[key] = bool(value)
except:
data[key] = False
else:
value = value.replace('\n', '\\n')
data[key] = str(value)
return data
@staticmethod
def _order_cols(cols: List[str]) -> List[str]:
"""
Arrange columns in a defined order.
Parameters
----------
cols: list
A list with elements in any order
Returns
-------
list
A list with elements in a particular order
"""
ORDER = ['id', 'subject', 'predicate', 'object', 'relation']
cols2 = []
for c in ORDER:
if c in cols:
cols2.append(c)
cols.remove(c)
return cols2 + cols
@staticmethod
def _add_to_tar(tar: tarfile.TarFile, filename: str, filecontent: pd.DataFrame) -> None:
"""
Write file contents to a given filename and add the file
to a specified tar archive.
Parameters
----------
tar: tarfile.TarFile
Tar archive handle
filename: str
Name of file to add to the archive
filecontent: pandas.DataFrame
DataFrame containing data to write to filename
"""
content = filecontent.encode()
with TemporaryFile() as tmp:
tmp.write(content)
tmp.seek(0)
info = tarfile.TarInfo(name=filename)
info.size = len(content)
tar.addfile(tarinfo=info, fileobj=tmp)