Warning

This document is for an in-development version of Galaxy. You can alternatively view this page in the latest release if it exists or view the top of the latest release's documentation.

Source code for galaxy.datatypes.sniff

"""
File format detector
"""

import bz2
import codecs
import gzip
import io
import logging
import os
import re
import shutil
import struct
import tempfile
import zipfile
from functools import partial
from typing import (
    Callable,
    Dict,
    IO,
    Iterable,
    NamedTuple,
    Optional,
    TYPE_CHECKING,
    Union,
)

from typing_extensions import Protocol

from galaxy.files.uris import stream_url_to_file as files_stream_url_to_file
from galaxy.util import (
    compression_utils,
    file_reader,
    is_binary,
)
from galaxy.util.checkers import (
    check_html,
    COMPRESSION_CHECK_FUNCTIONS,
    is_tar,
)
from galaxy.util.path import StrPath

try:
    import pylibmagic  # noqa: F401  # isort:skip
except ImportError:
    # Not available in conda, but docker image contains libmagic
    pass
import magic  # isort:skip

if TYPE_CHECKING:
    from .data import Data

log = logging.getLogger(__name__)

# Maximum number of leading (decompressed) bytes read from a dataset for sniffing.
# Override with the GALAXY_SNIFF_PREFIX_BYTES environment variable; unset or empty -> 1 MiB.
SNIFF_PREFIX_BYTES = int(os.environ.get("GALAXY_SNIFF_PREFIX_BYTES") or 2**20)

# Mime types reported by libmagic that are always treated as binary content.
BINARY_MIMETYPES = {
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
}


def get_test_fname(fname):
    """Return the path of ``fname`` inside the ``test`` directory next to this module.

    An ``AssertionError`` is raised when that path is not an existing regular file.
    """
    test_dir = os.path.join(os.path.dirname(__file__), "test")
    full_path = os.path.join(test_dir, fname)
    assert os.path.isfile(full_path), f"{full_path} is not a file"
    return full_path
def sniff_with_cls(cls, fname):
    """Report whether an instance of datatype ``cls`` sniffs the named test file as its own.

    The test file path is resolved first (a missing file raises). Any exception raised while
    instantiating ``cls`` or running its ``sniff`` is treated as "not matched".
    """
    path = get_test_fname(fname)
    try:
        matched = bool(cls().sniff(path))
    except Exception:
        return False
    return matched
# galaxy.files.uris.stream_url_to_file with prefix="gx_url_paste" pre-bound.
# NOTE(review): presumably the prefix names the temporary file the URL is streamed into;
# confirm against galaxy.files.uris.
stream_url_to_file = partial(files_stream_url_to_file, prefix="gx_url_paste")
def handle_composite_file(datatype, src_path, extra_files, name, is_binary, tmp_dir, tmp_prefix, upload_opts):
    """Move one uploaded component of a composite dataset into its extra-files directory.

    Non-binary components are first rewritten in place: line endings are always made POSIX,
    and runs of whitespace become tabs as well when ``upload_opts["space_to_tab"]`` is truthy.
    The file ends up at ``extra_files/name``. If ``datatype`` reports that this file needs
    grooming, it is groomed there.
    """
    if not is_binary:
        converter = convert_newlines_sep2tabs if upload_opts.get("space_to_tab") else convert_newlines
        converter(src_path, tmp_dir=tmp_dir, tmp_prefix=tmp_prefix)
    destination = os.path.join(extra_files, name)
    shutil.move(src_path, destination)
    # groom the dataset file content if required by the corresponding datatype definition
    if datatype and datatype.dataset_content_needs_grooming(destination):
        datatype.groom_dataset_content(destination)
class ConvertResult(NamedTuple):
    """Summary returned by the line-ending / separator conversion helpers."""

    # Line count reported by the conversion function.
    line_count: int
    # Path of the converted temporary file, or None when the input was converted in place.
    converted_path: Optional[str]
    # Whether line endings were changed (including appending a missing final newline).
    converted_newlines: bool
    # Whether the separator regular expression matched and was replaced by tabs.
    converted_regex: bool
class ConvertFunction(Protocol):
    """Structural type shared by the file conversion helpers selected by ``convert_function``.

    Implementations in this module rewrite ``fname`` either in place or into a new temporary
    file (created in ``tmp_dir`` with name prefix ``tmp_prefix``) and describe what changed.
    """

    def __call__(
        self,
        fname: str,
        in_place: bool = True,
        tmp_dir: Optional[str] = None,
        tmp_prefix: Optional[str] = "gxupload",
    ) -> ConvertResult:
        ...
def convert_newlines(
    fname: str,
    in_place: bool = True,
    tmp_dir: Optional[str] = None,
    tmp_prefix: Optional[str] = "gxupload",
    block_size: int = 128 * 1024,
    regexp=None,
) -> ConvertResult:
    """
    Converts a file from universal line endings (CRLF, lone CR, LF) to Posix (LF) line endings.

    A final newline is appended if the file does not end with one. If ``regexp`` (a compiled
    bytes pattern) is given, each of its matches is also replaced by a single tab. A match
    that straddles two read blocks still yields only one tab.

    :param fname: file to convert.
    :param in_place: replace ``fname`` with the converted output; otherwise leave ``fname``
        untouched and return the temporary output path.
    :param tmp_dir: directory for the temporary output file.
    :param tmp_prefix: name prefix of the temporary output file.
    :param block_size: number of bytes read per iteration.
    :param regexp: optional compiled bytes regular expression whose matches become tabs.
    :returns: ``ConvertResult`` with the number of newlines in the output (including one that
        was appended), the temporary path (``None`` when ``in_place``), and whether newlines
        and/or the regular expression changed anything.
    """
    line_count = 0
    converted_newlines = False
    converted_regex = False
    NEWLINE_BYTE = 10
    CR_BYTE = 13
    # Last raw (unconverted) byte read from the input.
    last_raw_byte = None
    # Last byte written to the output; None until something has been written.
    last_written_byte = None
    # True when the previous output block ended with a tab produced by a ``regexp`` match.
    pending_separator = False
    with tempfile.NamedTemporaryFile(mode="wb", prefix=tmp_prefix, dir=tmp_dir, delete=False) as fp, open(
        fname, mode="rb"
    ) as fi:
        raw_block = fi.read(block_size)
        while raw_block:
            block = raw_block
            if last_raw_byte == CR_BYTE and block.startswith(b"\n"):
                # The previous block ended with CR, which was already written as a newline;
                # drop the LF that completes that CRLF pair.
                block = block[1:]
            # Track the raw byte even when ``block`` became empty above. Otherwise a following
            # block starting with LF would be mistaken for the second half of the same CRLF.
            last_raw_byte = raw_block[-1]
            if block:
                if b"\r" in block:
                    block = block.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
                    converted_newlines = True
                if regexp:
                    split_block = regexp.split(block)
                    if len(split_block) > 1:
                        converted_regex = True
                        if pending_separator and split_block[0] == b"":
                            # This block continues a separator run whose tab was already written
                            # at the end of the previous block: don't emit a second tab.
                            split_block = split_block[1:]
                        block = b"\t".join(split_block)
                        pending_separator = split_block[-1] == b""
                    else:
                        pending_separator = False
                if block:
                    fp.write(block)
                    line_count += block.count(b"\n")
                    last_written_byte = block[-1]
            raw_block = fi.read(block_size)
        if last_written_byte is not None and last_written_byte != NEWLINE_BYTE:
            converted_newlines = True
            line_count += 1
            fp.write(b"\n")
    if in_place:
        shutil.move(fp.name, fname)
        # Return number of lines in file.
        return ConvertResult(line_count, None, converted_newlines, converted_regex)
    return ConvertResult(line_count, fp.name, converted_newlines, converted_regex)
def convert_sep2tabs(
    fname: str,
    in_place: bool = True,
    tmp_dir: Optional[str] = None,
    tmp_prefix: Optional[str] = "gxupload",
    block_size: int = 128 * 1024,
) -> ConvertResult:
    """
    Transforms a whitespace ('sep') separated file into a tab separated one.

    Every run of whitespace other than CR/LF becomes a single tab, including runs that
    straddle two read blocks. Line endings are left untouched.

    :param fname: file to convert.
    :param in_place: replace ``fname`` with the converted output; otherwise return the
        temporary output path.
    :param tmp_dir: directory for the temporary output file.
    :param tmp_prefix: name prefix of the temporary output file.
    :param block_size: number of bytes read per iteration.
    :returns: ``ConvertResult``. For each block the line count adds the number of LF bytes,
        or the number of CR bytes if the block has no LF. ``converted_newlines`` is always False.
    """
    patt: bytes = rb"[^\S\r\n]+"
    regexp = re.compile(patt)
    line_count = 0
    converted_newlines = False
    converted_regex = False
    # True when the previous output block ended with a tab produced by a separator match.
    pending_separator = False
    with tempfile.NamedTemporaryFile(mode="wb", prefix=tmp_prefix, dir=tmp_dir, delete=False) as fp, open(
        fname, mode="rb"
    ) as fi:
        block = fi.read(block_size)
        while block:
            split_block = regexp.split(block)
            if len(split_block) > 1:
                converted_regex = True
                if pending_separator and split_block[0] == b"":
                    # Continuation of a separator run already emitted as a tab.
                    split_block = split_block[1:]
                block = b"\t".join(split_block)
                pending_separator = split_block[-1] == b""
            else:
                pending_separator = False
            fp.write(block)
            line_count += block.count(b"\n") or block.count(b"\r")
            block = fi.read(block_size)
    if in_place:
        shutil.move(fp.name, fname)
        # Return number of lines in file.
        return ConvertResult(line_count, None, converted_newlines, converted_regex)
    return ConvertResult(line_count, fp.name, converted_newlines, converted_regex)
def convert_newlines_sep2tabs(
    fname: str, in_place: bool = True, tmp_dir: Optional[str] = None, tmp_prefix: Optional[str] = "gxupload"
) -> ConvertResult:
    """
    Normalize line endings to POSIX newlines and turn each run of non-newline whitespace into a tab.

    Delegates to ``convert_newlines`` with a whitespace-run pattern as its ``regexp``.
    """
    whitespace_run = re.compile(rb"[^\S\n]+")
    return convert_newlines(
        fname,
        in_place=in_place,
        tmp_dir=tmp_dir,
        tmp_prefix=tmp_prefix,
        regexp=whitespace_run,
    )
def iter_headers(fname_or_file_prefix, sep, count=60, comment_designator=None):
    """
    Yield up to ``count`` lines, each split by ``sep``, skipping comment lines.

    :param fname_or_file_prefix: a ``FilePrefix`` (only its cached prefix is read) or a path,
        which is opened with ``compression_utils.get_fileobj`` and closed when iteration ends.
    :param sep: separator for ``str.split``; ``None`` splits on runs of whitespace.
    :param count: maximum number of rows to yield. A value < 1 (e.g. ``-1``, as used by
        ``validate_tabular``) yields every row.
    :param comment_designator: lines starting with this prefix are skipped and not counted.
        ``None`` or ``""`` disables comment skipping.
    """
    opened_file = None
    if isinstance(fname_or_file_prefix, FilePrefix):
        lines = fname_or_file_prefix.line_iterator()
    else:
        opened_file = compression_utils.get_fileobj(fname_or_file_prefix)
        lines = opened_file
    try:
        yielded = 0
        for line in lines:
            line = line.rstrip("\n\r")
            if comment_designator and line.startswith(comment_designator):
                continue
            yield line.split(sep)
            yielded += 1
            if yielded == count:
                break
    finally:
        # Release the handle we opened as soon as iteration ends (exhausted, stopped early,
        # errored, or generator closed) instead of relying on garbage collection.
        if opened_file is not None:
            opened_file.close()
def validate_tabular(fname_or_file_prefix, validate_row, sep, comment_designator=None):
    """Pass every non-comment row (split by ``sep``) to ``validate_row``.

    There is no row limit. Any exception raised by ``validate_row`` propagates to the caller.
    """
    rows = iter_headers(fname_or_file_prefix, sep, count=-1, comment_designator=comment_designator)
    for fields in rows:
        validate_row(fields)
def get_headers(fname_or_file_prefix, sep, count=60, comment_designator=None):
    """
    Return at most ``count`` lines, each split by ``sep``, as a list of lists.

    Lines starting with ``comment_designator`` are skipped and do not count towards ``count``.

    >>> fname = get_test_fname('complete.bed')
    >>> get_headers(fname,'\\t') == [['chr7', '127475281', '127491632', 'NM_000230', '0', '+', '127486022', '127488767', '0', '3', '29,172,3225,', '0,10713,13126,'], ['chr7', '127486011', '127488900', 'D49487', '0', '+', '127486022', '127488767', '0', '2', '155,490,', '0,2399']]
    True
    >>> fname = get_test_fname('test.gff')
    >>> get_headers(fname, '\\t', count=5, comment_designator='#') == [[''], ['chr7', 'bed2gff', 'AR', '26731313', '26731437', '.', '+', '.', 'score'], ['chr7', 'bed2gff', 'AR', '26731491', '26731536', '.', '+', '.', 'score'], ['chr7', 'bed2gff', 'AR', '26731541', '26731649', '.', '+', '.', 'score'], ['chr7', 'bed2gff', 'AR', '26731659', '26731841', '.', '+', '.', 'score']]
    True
    """
    header_rows = iter_headers(fname_or_file_prefix, sep, count, comment_designator)
    return list(header_rows)
def is_column_based(fname_or_file_prefix, sep="\t", skip=0):
    """
    Checks whether the file is column based with respect to a separator
    (defaults to tab separator).

    The first ``skip`` non-comment rows (``#`` comments, up to the ``get_headers`` default of 60
    rows) are dropped. Every remaining non-blank row must then have the same number of fields,
    and that number must be at least 2. Binary prefixes and undecodable content yield False.

    >>> fname = get_test_fname('test.gff')
    >>> is_column_based(fname)
    True
    >>> fname = get_test_fname('test_tab.bed')
    >>> is_column_based(fname)
    True
    >>> is_column_based(fname, sep=' ')
    False
    >>> fname = get_test_fname('test_space.txt')
    >>> is_column_based(fname)
    False
    >>> is_column_based(fname, sep=' ')
    True
    >>> fname = get_test_fname('test_ensembl.tabular')
    >>> is_column_based(fname)
    True
    >>> fname = get_test_fname('test_tab1.tabular')
    >>> is_column_based(fname, sep=' ', skip=0)
    False
    >>> fname = get_test_fname('test_tab1.tabular')
    >>> is_column_based(fname)
    True
    """
    if getattr(fname_or_file_prefix, "binary", None) is True:
        return False
    try:
        rows = get_headers(fname_or_file_prefix, sep, comment_designator="#")[skip:]
    except UnicodeDecodeError:
        return False
    expected_columns = 0
    for row in rows:
        if not row or row == [""]:
            # Blank lines don't count either way.
            continue
        if not expected_columns:
            expected_columns = len(row)
        elif len(row) != expected_columns:
            return False
    return expected_columns >= 2
def guess_ext(fname_or_file_prefix: Union[str, "FilePrefix"], sniff_order, is_binary=None, auto_decompress=True):
    """
    Guess the datatype extension of a file (path or ``FilePrefix``) for the datatype factory.

    The sniffers in ``sniff_order`` are tried first. A ``tsv`` result is replaced by
    ``tabular`` when the content is also column based. If nothing matches, the fallback is
    ``binary`` for binary content, ``data`` for undecodable content, ``tabular`` for
    tab-separated columns, and ``txt`` otherwise. ``is_binary`` is accepted for backwards
    compatibility but not consulted.

    >>> from galaxy.datatypes.registry import example_datatype_registry_for_sample
    >>> datatypes_registry = example_datatype_registry_for_sample()
    >>> sniff_order = datatypes_registry.sniff_order
    >>> fname = get_test_fname('empty.txt')
    >>> guess_ext(fname, sniff_order)
    'txt'
    >>> fname = get_test_fname('megablast_xml_parser_test1.blastxml')
    >>> guess_ext(fname, sniff_order)
    'blastxml'
    >>> fname = get_test_fname('1.psl')
    >>> guess_ext(fname, sniff_order)
    'psl'
    >>> fname = get_test_fname('2.psl')
    >>> guess_ext(fname, sniff_order)
    'psl'
    >>> fname = get_test_fname('interval.interval')
    >>> guess_ext(fname, sniff_order)
    'interval'
    >>> fname = get_test_fname('interv1.bed')
    >>> guess_ext(fname, sniff_order)
    'bed'
    >>> fname = get_test_fname('test_tab.bed')
    >>> guess_ext(fname, sniff_order)
    'bed'
    >>> fname = get_test_fname('sequence.maf')
    >>> guess_ext(fname, sniff_order)
    'maf'
    >>> fname = get_test_fname('sequence.fasta')
    >>> guess_ext(fname, sniff_order)
    'fasta'
    >>> fname = get_test_fname('1.genbank')
    >>> guess_ext(fname, sniff_order)
    'genbank'
    >>> fname = get_test_fname('1.genbank.gz')
    >>> guess_ext(fname, sniff_order)
    'genbank.gz'
    >>> fname = get_test_fname('file.html')
    >>> guess_ext(fname, sniff_order)
    'html'
    >>> fname = get_test_fname('test.gtf')
    >>> guess_ext(fname, sniff_order)
    'gtf'
    >>> fname = get_test_fname('test.gff')
    >>> guess_ext(fname, sniff_order)
    'gff'
    >>> fname = get_test_fname('gff.gff3')
    >>> guess_ext(fname, sniff_order)
    'gff3'
    >>> fname = get_test_fname('2.txt')
    >>> guess_ext(fname, sniff_order)
    'txt'
    >>> fname = get_test_fname('test_tab2.tabular')
    >>> guess_ext(fname, sniff_order)
    'tabular'
    >>> fname = get_test_fname('3.txt')
    >>> guess_ext(fname, sniff_order)
    'txt'
    >>> fname = get_test_fname('test_tab1.tabular')
    >>> guess_ext(fname, sniff_order)
    'tabular'
    >>> fname = get_test_fname('alignment.lav')
    >>> guess_ext(fname, sniff_order)
    'lav'
    >>> fname = get_test_fname('1.sff')
    >>> guess_ext(fname, sniff_order)
    'sff'
    >>> fname = get_test_fname('1.bam')
    >>> guess_ext(fname, sniff_order)
    'bam'
    >>> fname = get_test_fname('3unsorted.bam')
    >>> guess_ext(fname, sniff_order)
    'unsorted.bam'
    >>> fname = get_test_fname('test.idpdb')
    >>> guess_ext(fname, sniff_order)
    'idpdb'
    >>> fname = get_test_fname('test.mz5')
    >>> guess_ext(fname, sniff_order)
    'h5'
    >>> fname = get_test_fname('issue1818.tabular')
    >>> guess_ext(fname, sniff_order)
    'tabular'
    >>> fname = get_test_fname('drugbank_drugs.cml')
    >>> guess_ext(fname, sniff_order)
    'cml'
    >>> fname = get_test_fname('q.fps')
    >>> guess_ext(fname, sniff_order)
    'fps'
    >>> fname = get_test_fname('drugbank_drugs.inchi')
    >>> guess_ext(fname, sniff_order)
    'inchi'
    >>> fname = get_test_fname('drugbank_drugs.mol2')
    >>> guess_ext(fname, sniff_order)
    'mol2'
    >>> fname = get_test_fname('drugbank_drugs.sdf')
    >>> guess_ext(fname, sniff_order)
    'sdf'
    >>> fname = get_test_fname('5e5z.pdb')
    >>> guess_ext(fname, sniff_order)
    'pdb'
    >>> fname = get_test_fname('Si_uppercase.cell')
    >>> guess_ext(fname, sniff_order)
    'cell'
    >>> fname = get_test_fname('Si_lowercase.cell')
    >>> guess_ext(fname, sniff_order)
    'cell'
    >>> fname = get_test_fname('Si.cif')
    >>> guess_ext(fname, sniff_order)
    'cif'
    >>> fname = get_test_fname('LaMnO3.cif')
    >>> guess_ext(fname, sniff_order)
    'cif'
    >>> fname = get_test_fname('Si.xyz')
    >>> guess_ext(fname, sniff_order)
    'xyz'
    >>> fname = get_test_fname('Si_multi.xyz')
    >>> guess_ext(fname, sniff_order)
    'xyz'
    >>> fname = get_test_fname('Si.extxyz')
    >>> guess_ext(fname, sniff_order)
    'extxyz'
    >>> fname = get_test_fname('Si.castep')
    >>> guess_ext(fname, sniff_order)
    'castep'
    >>> fname = get_test_fname('test.fits')
    >>> guess_ext(fname, sniff_order)
    'fits'
    >>> fname = get_test_fname('Si.param')
    >>> guess_ext(fname, sniff_order)
    'param'
    >>> fname = get_test_fname('Si.den_fmt')
    >>> guess_ext(fname, sniff_order)
    'den_fmt'
    >>> fname = get_test_fname('ethanol.magres')
    >>> guess_ext(fname, sniff_order)
    'magres'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.otu')
    >>> guess_ext(fname, sniff_order)
    'mothur.otu'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.lower.dist')
    >>> guess_ext(fname, sniff_order)
    'mothur.lower.dist'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.square.dist')
    >>> guess_ext(fname, sniff_order)
    'mothur.square.dist'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.pair.dist')
    >>> guess_ext(fname, sniff_order)
    'mothur.pair.dist'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.freq')
    >>> guess_ext(fname, sniff_order)
    'mothur.freq'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.quan')
    >>> guess_ext(fname, sniff_order)
    'mothur.quan'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.ref.taxonomy')
    >>> guess_ext(fname, sniff_order)
    'mothur.ref.taxonomy'
    >>> fname = get_test_fname('mothur_datatypetest_true.mothur.axes')
    >>> guess_ext(fname, sniff_order)
    'mothur.axes'
    >>> guess_ext(get_test_fname('infernal_model.cm'), sniff_order)
    'cm'
    >>> fname = get_test_fname('1.gg')
    >>> guess_ext(fname, sniff_order)
    'gg'
    >>> fname = get_test_fname('diamond_db.dmnd')
    >>> guess_ext(fname, sniff_order)
    'dmnd'
    >>> fname = get_test_fname('1.excel.xls')
    >>> guess_ext(fname, sniff_order, is_binary=True)
    'excel.xls'
    >>> fname = get_test_fname('biom2_sparse_otu_table_hdf5.biom2')
    >>> guess_ext(fname, sniff_order)
    'biom2'
    >>> fname = get_test_fname('454Score.pdf')
    >>> guess_ext(fname, sniff_order)
    'pdf'
    >>> fname = get_test_fname('1.obo')
    >>> guess_ext(fname, sniff_order)
    'obo'
    >>> fname = get_test_fname('1.arff')
    >>> guess_ext(fname, sniff_order)
    'arff'
    >>> fname = get_test_fname('1.afg')
    >>> guess_ext(fname, sniff_order)
    'afg'
    >>> fname = get_test_fname('1.owl')
    >>> guess_ext(fname, sniff_order)
    'owl'
    >>> fname = get_test_fname('Acanium.snaphmm')
    >>> guess_ext(fname, sniff_order)
    'snaphmm'
    >>> fname = get_test_fname('wiggle.wig')
    >>> guess_ext(fname, sniff_order)
    'wig'
    >>> fname = get_test_fname('example.iqtree')
    >>> guess_ext(fname, sniff_order)
    'iqtree'
    >>> fname = get_test_fname('1.stockholm')
    >>> guess_ext(fname, sniff_order)
    'stockholm'
    >>> fname = get_test_fname('1.xmfa')
    >>> guess_ext(fname, sniff_order)
    'xmfa'
    >>> fname = get_test_fname('test.blib')
    >>> guess_ext(fname, sniff_order)
    'blib'
    >>> fname = get_test_fname('test_strict_interleaved.phylip')
    >>> guess_ext(fname, sniff_order)
    'phylip'
    >>> fname = get_test_fname('test_relaxed_interleaved.phylip')
    >>> guess_ext(fname, sniff_order)
    'phylip'
    >>> fname = get_test_fname('1.smat')
    >>> guess_ext(fname, sniff_order)
    'smat'
    >>> fname = get_test_fname('1.ttl')
    >>> guess_ext(fname, sniff_order)
    'ttl'
    >>> fname = get_test_fname('1.hdt')
    >>> guess_ext(fname, sniff_order, is_binary=True)
    'hdt'
    >>> fname = get_test_fname('1.phyloxml')
    >>> guess_ext(fname, sniff_order)
    'phyloxml'
    >>> fname = get_test_fname('1.dzi')
    >>> guess_ext(fname, sniff_order)
    'dzi'
    >>> fname = get_test_fname('1.tiff')
    >>> guess_ext(fname, sniff_order)
    'tiff'
    >>> fname = get_test_fname('1.fastqsanger.gz')
    >>> guess_ext(fname, sniff_order)  # See test_datatype_registry for more compressed type tests.
    'fastqsanger.gz'
    >>> fname = get_test_fname('1.mtx')
    >>> guess_ext(fname, sniff_order)
    'mtx'
    >>> fname = get_test_fname('mc_preprocess_summ.metacyto_summary.txt')
    >>> guess_ext(fname, sniff_order)
    'metacyto_summary.txt'
    >>> fname = get_test_fname('Accuri_C6_A01_H2O.fcs')
    >>> guess_ext(fname, sniff_order)
    'fcs'
    >>> fname = get_test_fname('1imzml')
    >>> guess_ext(fname, sniff_order)  # This test case is ensuring doesn't throw exception, actual value could change if non-utf encoding handling improves.
    'data'
    >>> fname = get_test_fname('too_many_comments_gff3.tabular')
    >>> guess_ext(fname, sniff_order)  # It's a VCF but is sniffed as tabular because of the limit on the number of header lines we read
    'tabular'
    """
    file_prefix = _get_file_prefix(fname_or_file_prefix, auto_decompress=auto_decompress)
    sniffed = run_sniffers_raw(file_prefix, sniff_order)
    # Ugly hack for tsv vs tabular sniffing: tabular is preferred over tsv but has no sniffer,
    # so when tsv was sniffed, use tabular instead if the content is acceptable tabular.
    if sniffed == "tsv" and is_column_based(file_prefix, "\t", 1):
        sniffed = "tabular"
    if sniffed is not None:
        return sniffed
    # No sniffer matched. Skip the header check if the data is already known to be binary.
    if file_prefix.binary:
        return "binary"
    try:
        get_headers(file_prefix, None)
    except UnicodeDecodeError:
        return "data"  # default data type file extension
    if is_column_based(file_prefix, "\t", 1):
        return "tabular"  # default tabular data type file extension
    return "txt"  # default text data type file extension
def guess_ext_from_file_name(fname, registry, requested_ext="auto"):
    """Return ``requested_ext`` as-is unless it is ``"auto"``.

    In the ``"auto"`` case, return the extension of the datatype that ``registry`` associates
    with ``fname``.
    """
    if requested_ext == "auto":
        return registry.get_datatype_from_filename(fname).file_ext
    return requested_ext
class FilePrefix:
    """
    Size-limited, cached prefix of a dataset shared by all sniffers.

    The first ``SNIFF_PREFIX_BYTES`` bytes returned by ``compression_utils.get_fileobj_raw``
    (which also reports the detected compression format) are read once. They are kept as
    bytes and, when they decode as UTF-8, as text. libmagic supplies the mime type and encoding.
    """

    def __init__(self, filename, auto_decompress=True):
        non_utf8_error = None
        compressed_format = None
        contents_header_bytes = None
        contents_header = None  # First MAX_BYTES of the file.
        truncated = False
        # A future direction to optimize sniffing even more for sniffers at the top of the list
        # is to lazy load contents_header based on what interface is requested. For instance instead
        # of returning a StringIO directly in string_io() return an object that reads the contents and
        # populates contents_header while providing a StringIO-like interface until the file is read
        # but then would fallback to native string_io()
        try:
            compressed_format, f = compression_utils.get_fileobj_raw(filename, "rb")
            try:
                contents_header_bytes = f.read(SNIFF_PREFIX_BYTES)
                truncated = len(contents_header_bytes) == SNIFF_PREFIX_BYTES
                # A read that stopped at SNIFF_PREFIX_BYTES may have split a multi-byte UTF-8
                # character. With final=False the incremental decoder drops such an incomplete
                # trailing sequence instead of raising, so valid UTF-8 files larger than the prefix
                # are not misreported as non-UTF-8. Invalid bytes elsewhere still raise
                # UnicodeDecodeError.
                decoder = codecs.getincrementaldecoder("utf-8")()
                contents_header = decoder.decode(contents_header_bytes, final=not truncated)
            finally:
                f.close()
        except UnicodeDecodeError as e:
            non_utf8_error = e
        self.auto_decompress = auto_decompress
        self.truncated = truncated
        self.filename = filename
        self.non_utf8_error = non_utf8_error
        file_magic = magic.detect_from_content(contents_header_bytes)
        self.encoding = file_magic.encoding
        self.mime_type = file_magic.mime_type
        self.compressed_mime_type = None
        self.compressed_encoding = None
        if compressed_format:
            compressed_magic = magic.detect_from_filename(filename)
            self.compressed_mime_type = compressed_magic.mime_type
            self.compressed_encoding = compressed_magic.encoding
        self.compressed_format = compressed_format
        self.contents_header = contents_header
        self.contents_header_bytes = contents_header_bytes
        self._is_binary = None
        self._file_size = None

    @property
    def binary(self):
        """Whether the content is treated as binary (computed once, then cached)."""
        if self._is_binary is None:
            self._is_binary = bool({self.mime_type, self.compressed_mime_type} & BINARY_MIMETYPES) or is_binary(
                self.contents_header_bytes
            )
            if (not self._is_binary and self.encoding == "binary" and self.non_utf8_error) or (
                not self.auto_decompress and self.compressed_encoding == "binary"
            ):
                # Try harder ... if we have a non-utf-8 error, the file could be latin-1 encoded,
                # but magic would recognize this and set the encoding appropriately
                self._is_binary = True
        return self._is_binary

    @property
    def file_size(self):
        """Size of ``filename`` on disk in bytes (computed once, then cached)."""
        if self._file_size is None:
            self._file_size = os.path.getsize(self.filename)
        return self._file_size

    def string_io(self) -> io.StringIO:
        """Return the decoded prefix as a StringIO; re-raise the decode error for non-UTF-8 data."""
        if self.non_utf8_error is not None:
            raise self.non_utf8_error
        return io.StringIO(self.contents_header)

    def text_io(self, *args, **kwargs) -> io.TextIOWrapper:
        """Wrap the raw prefix bytes in a TextIOWrapper; arguments are forwarded to it."""
        return io.TextIOWrapper(io.BytesIO(self.contents_header_bytes), *args, **kwargs)

    def startswith(self, prefix):
        """Whether the decoded prefix starts with the string ``prefix``."""
        return self.string_io().read(len(prefix)) == prefix

    def line_iterator(self):
        """Yield the lines of the decoded prefix, keeping line terminators.

        A final line without a terminator is yielded only if the prefix was not truncated,
        since otherwise it may be an incomplete line.
        """
        s = self.string_io()
        s_len = len(s.getvalue())
        for line in iter(s.readline, ""):
            if line.endswith("\n") or line.endswith("\r"):
                yield line
            elif s.tell() == s_len and not self.truncated:
                # At the end, return the last line if it wasn't truncated when reading it in.
                yield line

    # Convenience wrappers around contents_header, shielding contents_header means we can
    # potentially do a better job lazy loading this data later on.

    def search(self, pattern):
        """Run the compiled ``pattern``'s ``search`` over the decoded prefix."""
        return pattern.search(self.contents_header)

    def search_str(self, query_str):
        """Whether ``query_str`` occurs in the decoded prefix."""
        return query_str in self.contents_header

    def magic_header(self, pattern):
        """
        Unpack header and get first element.

        Unpacks the leading ``struct.calcsize(pattern)`` bytes with ``pattern`` and returns the
        first value, or None when the prefix is shorter than that.
        """
        size = struct.calcsize(pattern)
        header_bytes = self.contents_header_bytes[:size]
        if len(header_bytes) < size:
            return None
        return struct.unpack(pattern, header_bytes)[0]

    def startswith_bytes(self, test_bytes):
        """Whether the raw prefix bytes start with ``test_bytes``."""
        return self.contents_header_bytes.startswith(test_bytes)
def _get_file_prefix(filename_or_file_prefix: Union[str, FilePrefix], auto_decompress: bool = True) -> FilePrefix:
    """Return the argument itself if it is already a FilePrefix, else build one from the path."""
    if isinstance(filename_or_file_prefix, FilePrefix):
        return filename_or_file_prefix
    return FilePrefix(filename_or_file_prefix, auto_decompress=auto_decompress)
def run_sniffers_raw(file_prefix: "FilePrefix", sniff_order: Iterable["Data"]):
    """Run through sniffers specified by sniff_order, return None if none match.

    Returns the ``file_ext`` of the first datatype whose sniffer accepts ``file_prefix``.
    Datatypes are skipped when their compressed-ness or binary-ness is incompatible with the
    prefix. Exceptions raised by individual sniffers count as "no match".
    """
    fname = file_prefix.filename
    file_ext = None
    for datatype in sniff_order:
        # Some classes may not have a sniff function, which is ok. In fact, Binary, Data,
        # Tabular and Text are examples of classes that should never have a sniff function.
        # Since these classes are default classes, they contain few rules to filter out data of
        # other formats, so they should be called from this function after all other datatypes
        # in sniff_order have not been successfully discovered.
        datatype_compressed = getattr(datatype, "compressed", False)
        if datatype_compressed and not file_prefix.compressed_format and not datatype.file_ext.endswith(".tar"):
            # we don't auto-detect tar as compressed
            continue
        if not datatype_compressed and file_prefix.compressed_format:
            continue
        if file_prefix.binary != datatype.is_binary and not datatype.is_binary == "maybe":
            # Binary detection doesn't match datatype ...
            compressed_data_for_compressed_text_datatype = (
                file_prefix.binary and file_prefix.compressed_format and datatype_compressed and not datatype.is_binary
            )
            if not compressed_data_for_compressed_text_datatype:
                # ... and mismatch is not due to compressed text data for a compressed text datatype
                continue
        try:
            if hasattr(datatype, "sniff_prefix"):
                datatype_compressed_format = getattr(datatype, "compressed_format", None)
                if file_prefix.compressed_format and datatype_compressed_format:
                    # Compare the compressed format detected to the expected.
                    if file_prefix.compressed_format != datatype_compressed_format:
                        continue
                if datatype.sniff_prefix(file_prefix):
                    file_ext = datatype.file_ext
                    break
            elif hasattr(datatype, "sniff") and datatype.sniff(fname):
                file_ext = datatype.file_ext
                break
        except Exception:
            # Deliberate best effort: a sniffer that raises simply does not match.
            pass
    return file_ext
def zip_single_fileobj(path: StrPath) -> IO[bytes]:
    """
    Open the first regular file (not a directory entry) of the zip archive at ``path``.

    :returns: a readable binary file object for that archive member.
    :raises ValueError: if the archive contains no regular file.
    """
    archive = zipfile.ZipFile(path)
    try:
        for name in archive.namelist():
            if not name.endswith("/"):
                return archive.open(name)
    except Exception:
        # Don't leak the archive handle if opening the member fails.
        archive.close()
        raise
    # Nothing to hand out: release the archive before reporting the problem.
    archive.close()
    raise ValueError("No file present in the zip file")
def build_sniff_from_prefix(klass):
    """Class decorator: attach a ``sniff(filename)`` method built on ``klass.sniff_prefix``.

    The generated ``sniff`` builds a ``FilePrefix`` for the file and rejects it in two cases:
    the file's compression does not match ``klass.compressed`` (when the datatype is
    compressed, ``klass.compressed_format`` is also checked if present), or otherwise
    ``sniff_prefix`` does not accept it.
    """

    def auto_sniff(self, filename):
        file_prefix = FilePrefix(filename)
        detected_format = file_prefix.compressed_format
        if getattr(self, "compressed", False):
            if not detected_format:
                # The datatype expects compressed content but this file is not compressed.
                return False
            if hasattr(self, "compressed_format") and self.compressed_format != detected_format:
                return False
        elif detected_format:
            # Plain datatypes never match compressed content.
            return False
        return self.sniff_prefix(file_prefix)

    klass.sniff = auto_sniff
    return klass
def disable_parent_class_sniffing(klass):
    """Class decorator: make ``sniff`` and ``sniff_prefix`` of ``klass`` always return False."""

    def never_sniff(self, filename):
        return False

    def never_sniff_prefix(self, file_prefix):
        return False

    klass.sniff = never_sniff
    klass.sniff_prefix = never_sniff_prefix
    return klass
class HandleCompressedFileResponse(NamedTuple):
    """Result of ``handle_compressed_file``."""

    # False when compressed content failed the compression check.
    is_valid: bool
    # Extension, possibly replaced by a sniffed keep-compressed datatype when ``auto`` was requested.
    ext: str
    # Path of the (possibly uncompressed) data.
    uncompressed_path: str
    # Compression format detected for the input, if any.
    compressed_type: Optional[str]
    # Whether the data at ``uncompressed_path`` is still compressed.
    is_compressed: Optional[bool]
def _decompress_to_temp_file(filename: str, compressed_type: str, tmp_prefix, tmp_dir, chunk_size: int) -> str:
    """Decompress ``filename`` into a new temporary file and return that file's path.

    The temporary file is removed again if decompression fails for any reason.
    ``OSError`` (e.g. corrupt data) and ``EOFError`` (raised by gzip/bz2 for truncated
    streams) are re-raised as ``OSError`` with a user-facing message.
    """
    with tempfile.NamedTemporaryFile(prefix=tmp_prefix, dir=tmp_dir, delete=False) as uncompressed:
        try:
            with DECOMPRESSION_FUNCTIONS[compressed_type](filename) as compressed_file:
                # TODO: it'd be ideal to convert to posix newlines and space-to-tab here as well
                for chunk in file_reader(compressed_file, chunk_size):
                    if not chunk:
                        break
                    uncompressed.write(chunk)
        except (OSError, EOFError) as e:
            os.remove(uncompressed.name)
            raise OSError(
                f"Problem uncompressing {compressed_type} data, please try retrieving the data uncompressed: {e}"
            ) from e
        except Exception:
            os.remove(uncompressed.name)
            raise
    return uncompressed.name


def handle_compressed_file(
    file_prefix: FilePrefix,
    datatypes_registry,
    ext: str = "auto",
    tmp_prefix: Optional[str] = "sniff_uncompress_",
    tmp_dir: Optional[str] = None,
    in_place: bool = False,
    check_content: bool = True,
) -> HandleCompressedFileResponse:
    """
    Check uploaded files for compression, check compressed file contents, and uncompress if necessary.

    Supports GZip, BZip2, and the first file in a Zip file.

    For performance reasons, the temporary file used for uncompression is located in the same directory as the
    input/output file. This behavior can be changed with the `tmp_dir` param.

    ``ext`` as returned will only be changed from the ``ext`` input param if the param was an autodetect type (``auto``)
    and the file was sniffed as a keep-compressed datatype.

    ``is_valid`` as returned is False only when the file is compressed, ``check_content`` is set
    and the compression check reported invalid contents (for a zip file, the first file). This lets
    lengthy decompression be bypassed when the content is invalid; otherwise the caller should be
    checking content.

    Decompression errors remove the partial temporary file and are raised as ``OSError``.
    """
    CHUNK_SIZE = 2**20  # 1Mb
    is_compressed = False
    compressed_type = None
    keep_compressed = False
    is_valid = False
    filename = file_prefix.filename
    uncompressed_path = filename
    tmp_dir = tmp_dir or os.path.dirname(filename)
    check_compressed_function = COMPRESSION_CHECK_FUNCTIONS.get(file_prefix.compressed_format)
    if check_compressed_function:
        is_compressed, is_valid = check_compressed_function(filename, check_content=check_content)
        compressed_type = file_prefix.compressed_format
    if is_compressed and is_valid:
        if ext in AUTO_DETECT_EXTENSIONS:
            # attempt to sniff for a keep-compressed datatype (observing the sniff order)
            sniff_datatypes = filter(lambda d: getattr(d, "compressed", False), datatypes_registry.sniff_order)
            sniffed_ext = run_sniffers_raw(file_prefix, sniff_datatypes)
            if sniffed_ext:
                ext = sniffed_ext
                keep_compressed = True
        else:
            datatype = datatypes_registry.get_datatype_by_extension(ext)
            keep_compressed = getattr(datatype, "compressed", False)
    # don't waste time decompressing if we sniff invalid contents
    if is_compressed and is_valid and file_prefix.auto_decompress and not keep_compressed:
        assert compressed_type  # Tell type checker is_compressed will only be true if compressed_type is also set.
        uncompressed_path = _decompress_to_temp_file(filename, compressed_type, tmp_prefix, tmp_dir, CHUNK_SIZE)
        is_compressed = False
        if in_place:
            # Replace the compressed file with the uncompressed file
            shutil.move(uncompressed_path, filename)
            uncompressed_path = filename
    elif not is_compressed or not check_content:
        is_valid = True
    return HandleCompressedFileResponse(is_valid, ext, uncompressed_path, compressed_type, is_compressed)
def handle_uploaded_dataset_file(filename, *args, **kwds) -> str:
    """Legacy wrapper about handle_uploaded_dataset_file_internal for tools using it.

    Builds a ``FilePrefix`` for ``filename``, forwards the remaining arguments, and returns
    only the resulting extension.
    """
    response = handle_uploaded_dataset_file_internal(FilePrefix(filename), *args, **kwds)
    return response.ext
class HandleUploadedDatasetFileInternalResponse(NamedTuple):
    """Result of ``handle_uploaded_dataset_file_internal``."""

    # Final (requested or sniffed) datatype extension.
    ext: str
    # Path of the processed (decompressed and/or converted) data.
    converted_path: str
    # Compression format detected for the original upload, if any.
    compressed_type: Optional[str]
    # Whether line endings were converted.
    converted_newlines: bool
    # Whether whitespace was converted to tabs.
    converted_spaces: bool
def convert_function(convert_to_posix_lines, convert_spaces_to_tabs) -> ConvertFunction:
    """Pick the converter for the requested combination of conversions.

    At least one conversion must be requested (asserted). Newline-only requests get
    ``convert_newlines``, tab-only requests get ``convert_sep2tabs``, and requests for both
    get ``convert_newlines_sep2tabs``.
    """
    assert convert_to_posix_lines or convert_spaces_to_tabs
    if not convert_spaces_to_tabs:
        return convert_newlines
    if not convert_to_posix_lines:
        return convert_sep2tabs
    return convert_newlines_sep2tabs
def handle_uploaded_dataset_file_internal(
    file_prefix: FilePrefix,
    datatypes_registry,
    ext: str = "auto",
    tmp_prefix: Optional[str] = "sniff_upload_",
    tmp_dir: Optional[str] = None,
    in_place: bool = False,
    check_content: bool = True,
    is_binary: Optional[bool] = None,
    uploaded_file_ext: Optional[str] = None,
    convert_to_posix_lines: Optional[bool] = None,
    convert_spaces_to_tabs: Optional[bool] = None,
) -> HandleUploadedDatasetFileInternalResponse:
    """
    Decompress (if needed), optionally normalize, sniff and validate an uploaded dataset file.

    Steps:

    1. Decompress via ``handle_compressed_file``.
    2. Reject invalid compressed content, including TAR archives.
    3. For non-binary, uncompressed data, optionally convert line endings and/or spaces.
    4. Sniff the extension once, on the final content, when ``ext`` is an auto-detect value.
    5. Reject HTML content when ``check_content`` is set.

    Any temporary file created here is removed if a step raises. ``is_binary`` and
    ``uploaded_file_ext`` are accepted for backwards compatibility but not consulted.

    :raises InappropriateDatasetContentError: for rejected content.
    """
    is_valid, ext, converted_path, compressed_type, is_compressed = handle_compressed_file(
        file_prefix,
        datatypes_registry,
        ext=ext,
        tmp_prefix=tmp_prefix,
        tmp_dir=tmp_dir,
        in_place=in_place,
        check_content=check_content,
    )
    converted_newlines = False
    converted_spaces = False
    try:
        if not is_valid:
            if is_tar(converted_path):
                raise InappropriateDatasetContentError("TAR file uploads are not supported")
            raise InappropriateDatasetContentError("The uploaded compressed file contains invalid content")

        is_binary = file_prefix.binary
        if not is_binary and not is_compressed and (convert_to_posix_lines or convert_spaces_to_tabs):
            # Convert universal line endings to Posix line endings, spaces to tabs (if desired)
            convert_fxn = convert_function(convert_to_posix_lines, convert_spaces_to_tabs)
            _line_count, new_path, converted_newlines, converted_spaces = convert_fxn(
                converted_path, in_place=in_place, tmp_dir=tmp_dir, tmp_prefix=tmp_prefix
            )
            if not in_place:
                # Validate before deleting anything, so a failure leaves converted_path pointing
                # at an existing file for the cleanup below.
                assert new_path
                if converted_path and file_prefix.filename != converted_path:
                    os.unlink(converted_path)
                converted_path = new_path

        if ext in AUTO_DETECT_EXTENSIONS:
            # Sniff once, on the final (decompressed and converted) content, rather than
            # sniffing both before and after conversion.
            ext = guess_ext(
                converted_path,
                sniff_order=datatypes_registry.sniff_order,
                auto_decompress=file_prefix.auto_decompress,
            )

        if not is_binary and check_content and check_html(converted_path):
            raise InappropriateDatasetContentError("The uploaded file contains invalid HTML content")
    except Exception:
        if file_prefix.filename != converted_path:
            os.unlink(converted_path)
        raise
    return HandleUploadedDatasetFileInternalResponse(
        ext, converted_path, compressed_type, converted_newlines, converted_spaces
    )
# Requested extensions that trigger datatype auto-detection (sniffing).
AUTO_DETECT_EXTENSIONS = ["auto"]  # should 'data' also cause auto detect?

# Compression format name -> callable opening the file for decompressed binary reading.
DECOMPRESSION_FUNCTIONS: Dict[str, Callable] = {
    "gzip": gzip.GzipFile,
    "bz2": bz2.BZ2File,
    "zip": zip_single_fileobj,
}
class InappropriateDatasetContentError(Exception):
    """Raised when uploaded content is rejected (e.g. TAR archives, invalid compressed data, HTML)."""