Source code for beaker.services.dataset

import io
import os
import urllib.parse
from datetime import datetime
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    ClassVar,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Union,
)

from ..aliases import PathOrStr
from ..data_model import *
from ..exceptions import *
from ..util import log_and_wait, path_is_relative_to, retriable
from .service_client import ServiceClient

if TYPE_CHECKING:
    from requests import Response
    from rich.progress import Progress, TaskID


is_canceled = None


class DatasetClient(ServiceClient):
    """
    Accessed via :data:`Beaker.dataset <beaker.Beaker.dataset>`.
    """

    HEADER_UPLOAD_ID = "Upload-ID"
    HEADER_UPLOAD_LENGTH = "Upload-Length"
    HEADER_UPLOAD_OFFSET = "Upload-Offset"
    HEADER_DIGEST = "Digest"
    HEADER_LAST_MODIFIED = "Last-Modified"
    HEADER_CONTENT_LENGTH = "Content-Length"

    REQUEST_SIZE_LIMIT: ClassVar[int] = 32 * 1024 * 1024

    DOWNLOAD_CHUNK_SIZE: ClassVar[int] = 10 * 1024
    """
    The default buffer size for downloads.
    """

    def get(self, dataset: str) -> Dataset:
        """
        Get info about a dataset.

        :param dataset: The dataset ID or name.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises RequestException: Any other exception that can occur when
            contacting the Beaker server.
        """

        def _get(id: str) -> Dataset:
            return Dataset.from_json(
                self.request(
                    f"datasets/{self.url_quote(id)}",
                    exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(id))},
                ).json()
            )

        try:
            # Could be a dataset ID or full name, so we try that first.
            return _get(dataset)
        except DatasetNotFound:
            if "/" not in dataset:
                # Try with adding the account name.
                try:
                    return _get(f"{self.beaker.account.name}/{dataset}")
                except DatasetNotFound:
                    pass

                # Try searching the default workspace.
                if self.config.default_workspace is not None:
                    matches = self.beaker.workspace.datasets(match=dataset, limit=1)
                    if matches:
                        return matches[0]
            raise

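    # Usage sketch for ``get()`` (hypothetical dataset name), assuming a configured
    # ``Beaker`` client:
    #
    #     beaker = Beaker.from_env()
    #     dataset = beaker.dataset.get("username/my-dataset")
    #     print(dataset.id, dataset.name)
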
    def create(
        self,
        name: str,
        *sources: PathOrStr,
        target: Optional[PathOrStr] = None,
        workspace: Optional[str] = None,
        description: Optional[str] = None,
        force: bool = False,
        max_workers: Optional[int] = None,
        quiet: bool = False,
        commit: bool = True,
        strip_paths: bool = False,
    ) -> Dataset:
        """
        Create a dataset with the source file(s).

        :param name: The name to assign to the new dataset.
        :param sources: Local source files or directories to upload to the dataset.
        :param target: If specified, all source files/directories will be uploaded under
            a directory of this name.
        :param workspace: The workspace to upload the dataset to. If not specified,
            :data:`Beaker.config.default_workspace <beaker.Config.default_workspace>` is used.
        :param description: Text description for the dataset.
        :param force: If ``True`` and a dataset by the given name already exists, it will be overwritten.
        :param max_workers: The maximum number of thread pool workers to use to upload files concurrently.
        :param quiet: If ``True``, progress won't be displayed.
        :param commit: Whether to commit the dataset after successful upload.
        :param strip_paths: If ``True``, all source files and directories will be uploaded under their name,
            not their path. E.g. the file "docs/source/index.rst" would be uploaded as just "index.rst",
            instead of "docs/source/index.rst".

            .. note::
                This only applies to source paths that are children of the current working directory.
                If a source path is outside of the current working directory, it will always be uploaded
                under its name only.

        :raises ValueError: If the name is invalid.
        :raises DatasetConflict: If a dataset by that name already exists and ``force=False``.
        :raises UnexpectedEOFError: If a source is a directory and the contents of one of the directory's
            files changes while creating the dataset.
        :raises FileNotFoundError: If a source doesn't exist.
        :raises WorkspaceNotSet: If neither ``workspace`` nor
            :data:`Beaker.config.default_workspace <beaker.Config.default_workspace>` are set.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        self.validate_beaker_name(name)
        workspace_id = self.resolve_workspace(workspace).id

        # Create the dataset.
        def make_dataset() -> Dataset:
            return Dataset.from_json(
                self.request(
                    "datasets",
                    method="POST",
                    query={"name": name},
                    data=DatasetSpec(workspace=workspace_id, description=description),
                    exceptions_for_status={409: DatasetConflict(name)},
                ).json()
            )

        try:
            dataset_info = make_dataset()
        except DatasetConflict:
            if force:
                self.delete(f"{self.beaker.account.whoami().name}/{name}")
                dataset_info = make_dataset()
            else:
                raise

        assert dataset_info.storage is not None

        # Upload the file(s).
        if sources:
            self.sync(
                dataset_info,
                *sources,
                target=target,
                quiet=quiet,
                max_workers=max_workers,
                strip_paths=strip_paths,
            )

        # Commit the dataset.
        if commit:
            self.commit(dataset_info.id)

        # Return info about the dataset.
        return self.get(dataset_info.id)

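    # Usage sketch for ``create()`` (hypothetical names and paths), assuming a
    # configured ``Beaker`` client and that the local file exists:
    #
    #     dataset = beaker.dataset.create(
    #         "my-dataset",
    #         "docs/source/index.rst",
    #         workspace="my_org/my_workspace",
    #         description="Example upload",
    #         strip_paths=True,  # uploads the file as just "index.rst"
    #     )
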
    def commit(self, dataset: Union[str, Dataset]) -> Dataset:
        """
        Commit the dataset.

        :param dataset: The dataset ID, name, or object.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset_id = self.resolve_dataset(dataset).id

        @retriable()
        def commit() -> Dataset:
            # It's okay to retry this because committing a dataset multiple
            # times does nothing.
            return Dataset.from_json(
                self.request(
                    f"datasets/{self.url_quote(dataset_id)}",
                    method="PATCH",
                    data=DatasetPatch(commit=True),
                    exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(dataset))},
                ).json()
            )

        return commit()

    def fetch(
        self,
        dataset: Union[str, Dataset],
        target: Optional[PathOrStr] = None,
        prefix: Optional[str] = None,
        force: bool = False,
        max_workers: Optional[int] = None,
        quiet: bool = False,
        validate_checksum: bool = True,
        chunk_size: Optional[int] = None,
    ):
        """
        Download a dataset.

        :param dataset: The dataset ID, name, or object.
        :param target: The target path to download fetched data to. Defaults to ``Path(.)``.
        :param prefix: Only download files that start with the given prefix.
        :param max_workers: The maximum number of thread pool workers to use to download files concurrently.
        :param force: If ``True``, existing local files will be overwritten.
        :param quiet: If ``True``, progress won't be displayed.
        :param validate_checksum: If ``True``, the checksum of every file downloaded will be verified.
        :param chunk_size: The size of the buffer (in bytes) to use while downloading each file.
            Defaults to :data:`DOWNLOAD_CHUNK_SIZE`.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetReadError: If the :data:`~beaker.data_model.dataset.Dataset.storage` hasn't been set.
        :raises FileExistsError: If ``force=False`` and an existing local file clashes with a file
            in the Beaker dataset.
        :raises ChecksumFailedError: If ``validate_checksum=True`` and the digest of one of the
            downloaded files doesn't match the expected digest.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset)
        if dataset.storage is None:
            # Might need to get dataset again if 'storage' hasn't been set yet.
            dataset = self.get(dataset.id)
        if dataset.storage is None:
            raise DatasetReadError(dataset.id)

        dataset_info = DatasetInfo.from_json(
            self.request(
                f"datasets/{dataset.id}/files",
                exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(dataset.id))},
            ).json()
        )
        total_bytes_to_download: int = dataset_info.size.bytes
        total_downloaded: int = 0

        target = Path(target or Path("."))
        target.mkdir(exist_ok=True, parents=True)

        from ..progress import get_sized_dataset_fetch_progress

        progress = get_sized_dataset_fetch_progress(quiet)

        with progress:
            bytes_task = progress.add_task("Downloading dataset")
            progress.update(bytes_task, total=total_bytes_to_download)

            import concurrent.futures
            import threading

            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                global is_canceled
                is_canceled = threading.Event()
                download_futures = []
                try:
                    for file_info in dataset_info.page.data:
                        if prefix is not None and not file_info.path.startswith(prefix):
                            continue
                        target_path = target / Path(file_info.path)
                        if not force and target_path.exists():
                            raise FileExistsError(file_info.path)
                        future = executor.submit(
                            self._download_file,
                            dataset,
                            file_info,
                            target_path,
                            progress=progress,
                            task_id=bytes_task,
                            validate_checksum=validate_checksum,
                            chunk_size=chunk_size,
                        )
                        download_futures.append(future)

                    for future in concurrent.futures.as_completed(download_futures):
                        total_downloaded += future.result()
                except KeyboardInterrupt:
                    self.logger.warning("Received KeyboardInterrupt, canceling download workers...")
                    is_canceled.set()  # type: ignore
                    for future in download_futures:
                        future.cancel()
                    executor.shutdown(wait=True)
                    raise

            progress.update(bytes_task, total=total_downloaded, completed=total_downloaded)

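    # Usage sketch for ``fetch()`` (hypothetical names and paths), assuming a
    # configured ``Beaker`` client:
    #
    #     beaker.dataset.fetch(
    #         "username/my-dataset",
    #         target="./my-dataset",  # local directory to download into
    #         prefix="data/",         # only download files under "data/"
    #         force=True,             # overwrite any existing local files
    #     )
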
    def stream_file(
        self,
        dataset: Union[str, Dataset],
        file: Union[str, FileInfo],
        offset: int = 0,
        length: int = -1,
        quiet: bool = False,
        validate_checksum: bool = True,
        chunk_size: Optional[int] = None,
    ) -> Generator[bytes, None, None]:
        """
        Stream download the contents of a single file from a dataset.

        .. seealso::
            :meth:`get_file()` is similar but returns the entire contents at once instead of
            a generator over the contents.

        :param dataset: The dataset ID, name, or object.
        :param file: The path of the file within the dataset or the corresponding
            :class:`~beaker.data_model.dataset.FileInfo` object.
        :param offset: Offset to start from, in bytes.
        :param length: Number of bytes to read.
        :param quiet: If ``True``, progress won't be displayed.
        :param validate_checksum: If ``True``, the checksum of the downloaded bytes will be verified.
        :param chunk_size: The size of the buffer (in bytes) to use while downloading each file.
            Defaults to :data:`DOWNLOAD_CHUNK_SIZE`.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetReadError: If the :data:`~beaker.data_model.dataset.Dataset.storage` hasn't been set.
        :raises FileNotFoundError: If the file doesn't exist in the dataset.
        :raises ChecksumFailedError: If ``validate_checksum=True`` and the digest of the downloaded
            bytes doesn't match the expected digest.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.

        :examples:

        >>> total_bytes = 0
        >>> with open(tmp_path / squad_dataset_file_name, "wb") as f:
        ...     for chunk in beaker.dataset.stream_file(squad_dataset_name, squad_dataset_file_name, quiet=True):
        ...         total_bytes += f.write(chunk)
        """
        dataset = self.resolve_dataset(dataset, ensure_storage=True)
        file_info = file if isinstance(file, FileInfo) else self.file_info(dataset, file)

        from ..progress import get_unsized_dataset_fetch_progress

        with get_unsized_dataset_fetch_progress(quiet=quiet) as progress:
            task_id = progress.add_task("Downloading", total=None)
            for bytes_chunk in self._stream_file(
                dataset,
                file_info,
                offset=offset,
                length=length,
                validate_checksum=validate_checksum,
                chunk_size=chunk_size,
            ):
                progress.update(task_id, advance=len(bytes_chunk))
                yield bytes_chunk

    def get_file(
        self,
        dataset: Union[str, Dataset],
        file: Union[str, FileInfo],
        offset: int = 0,
        length: int = -1,
        quiet: bool = False,
        validate_checksum: bool = True,
        chunk_size: Optional[int] = None,
    ) -> bytes:
        """
        Download the contents of a single file from a dataset.

        .. seealso::
            :meth:`stream_file()` is similar but returns a generator over the contents.

        :param dataset: The dataset ID, name, or object.
        :param file: The path of the file within the dataset or the corresponding
            :class:`~beaker.data_model.dataset.FileInfo` object.
        :param offset: Offset to start from, in bytes.
        :param length: Number of bytes to read.
        :param quiet: If ``True``, progress won't be displayed.
        :param validate_checksum: If ``True``, the checksum of the downloaded bytes will be verified.
        :param chunk_size: The size of the buffer (in bytes) to use while downloading each file.
            Defaults to :data:`DOWNLOAD_CHUNK_SIZE`.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetReadError: If the :data:`~beaker.data_model.dataset.Dataset.storage` hasn't been set.
        :raises FileNotFoundError: If the file doesn't exist in the dataset.
        :raises ChecksumFailedError: If ``validate_checksum=True`` and the digest of the downloaded
            bytes doesn't match the expected digest.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.

        :examples:

        >>> contents = beaker.dataset.get_file(squad_dataset_name, squad_dataset_file_name, quiet=True)
        """

        @retriable(recoverable_errors=(RequestException, ChecksumFailedError))
        def _get_file() -> bytes:
            return b"".join(
                self.stream_file(
                    dataset,
                    file,
                    offset=offset,
                    length=length,
                    quiet=quiet,
                    validate_checksum=validate_checksum,
                    chunk_size=chunk_size,
                )
            )

        return _get_file()

    def file_info(self, dataset: Union[str, Dataset], file_name: str) -> FileInfo:
        """
        Get the :class:`~beaker.data_model.dataset.FileInfo` for a file in a dataset.

        :param dataset: The dataset ID, name, or object.
        :param file_name: The path of the file within the dataset.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetReadError: If the :data:`~beaker.data_model.dataset.Dataset.storage` hasn't been set.
        :raises FileNotFoundError: If the file doesn't exist in the dataset.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset, ensure_storage=True)
        assert dataset.storage is not None
        if dataset.storage.scheme == "fh":
            response = self.request(
                f"datasets/{dataset.storage.id}/files/{file_name}",
                method="HEAD",
                token=dataset.storage.token,
                base_url=dataset.storage.base_url,
                exceptions_for_status={404: FileNotFoundError(file_name)},
            )
            size_str = response.headers.get(self.HEADER_CONTENT_LENGTH)
            size = int(size_str) if size_str else None
            return FileInfo(
                path=file_name,
                digest=Digest.from_encoded(response.headers[self.HEADER_DIGEST]),
                updated=datetime.strptime(
                    response.headers[self.HEADER_LAST_MODIFIED], "%a, %d %b %Y %H:%M:%S %Z"
                ),
                size=size,
            )
        else:
            # TODO (epwalsh): make a HEAD request once Beaker supports that
            # (https://github.com/allenai/beaker/issues/2961)
            response = self.request(
                f"datasets/{dataset.id}/files/{urllib.parse.quote(file_name, safe='')}",
                stream=True,
                exceptions_for_status={404: FileNotFoundError(file_name)},
            )
            response.close()
            size_str = response.headers.get(self.HEADER_CONTENT_LENGTH)
            size = int(size_str) if size_str else None
            digest = response.headers.get(self.HEADER_DIGEST)
            return FileInfo(
                path=file_name,
                digest=None if digest is None else Digest.from_encoded(digest),
                updated=datetime.strptime(
                    response.headers[self.HEADER_LAST_MODIFIED], "%a, %d %b %Y %H:%M:%S %Z"
                ),
                size=size,
            )

    def delete(self, dataset: Union[str, Dataset]):
        """
        Delete a dataset.

        :param dataset: The dataset ID, name, or object.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset_id = self.resolve_dataset(dataset).id
        self.request(
            f"datasets/{self.url_quote(dataset_id)}",
            method="DELETE",
            exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(dataset))},
        )

    def sync(
        self,
        dataset: Union[str, Dataset],
        *sources: PathOrStr,
        target: Optional[PathOrStr] = None,
        quiet: bool = False,
        max_workers: Optional[int] = None,
        strip_paths: bool = False,
    ) -> None:
        """
        Sync local files or directories to an uncommitted dataset.

        :param dataset: The dataset ID, name, or object.
        :param sources: Local source files or directories to upload to the dataset.
        :param target: If specified, all source files/directories will be uploaded under
            a directory of this name.
        :param max_workers: The maximum number of thread pool workers to use to upload files concurrently.
        :param quiet: If ``True``, progress won't be displayed.
        :param strip_paths: If ``True``, all source files and directories will be uploaded under their name,
            not their path. E.g. the file "docs/source/index.rst" would be uploaded as just "index.rst",
            instead of "docs/source/index.rst".

            .. note::
                This only applies to source paths that are children of the current working directory.
                If a source path is outside of the current working directory, it will always be uploaded
                under its name only.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetWriteError: If the dataset was already committed.
        :raises FileNotFoundError: If a source doesn't exist.
        :raises UnexpectedEOFError: If a source is a directory and the contents of one of the directory's
            files changes while creating the dataset.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset)
        if dataset.committed is not None:
            raise DatasetWriteError(dataset.id)

        from ..progress import get_dataset_sync_progress

        with get_dataset_sync_progress(quiet) as progress:
            bytes_task = progress.add_task("Uploading dataset")
            total_bytes = 0
            # Map source path to (target_path, size).
            path_info: Dict[Path, Tuple[Path, int]] = {}
            for source in sources:
                source = Path(source)
                strip_path = strip_paths or not path_is_relative_to(source, ".")
                if source.is_file():
                    target_path = Path(source.name) if strip_path else source
                    if target is not None:
                        target_path = Path(str(target)) / target_path
                    size = source.lstat().st_size
                    path_info[source] = (target_path, size)
                    total_bytes += size
                elif source.is_dir():
                    for path in source.glob("**/*"):
                        if path.is_dir():
                            continue
                        target_path = path.relative_to(source) if strip_path else path
                        if target is not None:
                            target_path = Path(str(target)) / target_path
                        size = path.lstat().st_size
                        if size == 0:
                            continue
                        path_info[path] = (target_path, size)
                        total_bytes += size
                else:
                    raise FileNotFoundError(source)

            import concurrent.futures

            progress.update(bytes_task, total=total_bytes)

            # Now upload.
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Dispatch tasks to thread pool executor.
                future_to_path = {}
                for path, (target_path, size) in path_info.items():
                    future = executor.submit(
                        self._upload_file,
                        dataset,
                        size,
                        path,
                        target_path,
                        progress,
                        bytes_task,
                        ignore_errors=True,
                    )
                    future_to_path[future] = path

                # Collect completed tasks.
                for future in concurrent.futures.as_completed(future_to_path):
                    path = future_to_path[future]
                    original_size = path_info[path][1]
                    actual_size = future.result()
                    if actual_size != original_size:
                        # If the size of the file has changed since we started, adjust total.
                        total_bytes += actual_size - original_size
                        progress.update(bytes_task, total=total_bytes)

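    # Usage sketch for ``sync()`` (hypothetical names and paths), assuming the
    # dataset was created with ``commit=False`` so it can still be written to:
    #
    #     dataset = beaker.dataset.create("my-dataset", commit=False)
    #     beaker.dataset.sync(dataset, "results/", target="run-001")
    #     beaker.dataset.commit(dataset)
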
    def upload(
        self,
        dataset: Union[str, Dataset],
        source: bytes,
        target: PathOrStr,
        quiet: bool = False,
    ) -> None:
        """
        Upload raw bytes to an uncommitted dataset.

        :param dataset: The dataset ID, name, or object.
        :param source: The raw bytes to upload to the dataset.
        :param target: The name to assign to the file for the bytes in the dataset.
        :param quiet: If ``True``, progress won't be displayed.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetWriteError: If the dataset was already committed.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset)
        if dataset.committed is not None:
            raise DatasetWriteError(dataset.id)

        from ..progress import get_dataset_sync_progress

        size = len(source)
        with get_dataset_sync_progress(quiet) as progress:
            task_id = progress.add_task("Uploading source")
            if size is not None:
                progress.update(task_id, total=size)
            self._upload_file(dataset, size, source, target, progress, task_id)

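    # Usage sketch for ``upload()`` (hypothetical names), assuming ``dataset`` is
    # still uncommitted:
    #
    #     beaker.dataset.upload(dataset, b'{"step": 1}', "metrics.json")
    #     beaker.dataset.commit(dataset)
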
    def ls(self, dataset: Union[str, Dataset], prefix: Optional[str] = None) -> List[FileInfo]:
        """
        List files in a dataset.

        :param dataset: The dataset ID, name, or object.
        :param prefix: An optional path prefix to filter by.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetReadError: If the :data:`~beaker.data_model.dataset.Dataset.storage` hasn't been set.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset)
        query = {} if prefix is None else {"prefix": prefix}
        info = DatasetInfo.from_json(
            self.request(
                f"datasets/{dataset.id}/files",
                query=query,
                exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(dataset.id))},
            ).json()
        )
        return list(info.page.data)

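    # Usage sketch for ``ls()`` (hypothetical dataset name), assuming a configured
    # ``Beaker`` client:
    #
    #     for file_info in beaker.dataset.ls("username/my-dataset", prefix="data/"):
    #         print(file_info.path, file_info.size)
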
    def size(self, dataset: Union[str, Dataset]) -> int:
        """
        Calculate the size of a dataset, in bytes.

        :param dataset: The dataset ID, name, or object.

        :raises DatasetNotFound: If the dataset can't be found.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        dataset = self.resolve_dataset(dataset)
        info = DatasetInfo.from_json(
            self.request(
                f"datasets/{dataset.id}/files",
                exceptions_for_status={404: DatasetNotFound(self._not_found_err_msg(dataset.id))},
            ).json()
        )
        return info.size.bytes

    def rename(self, dataset: Union[str, Dataset], name: str) -> Dataset:
        """
        Rename a dataset.

        :param dataset: The dataset ID, name, or object.
        :param name: The new name of the dataset.

        :raises ValueError: If the new name is invalid.
        :raises DatasetNotFound: If the dataset can't be found.
        :raises DatasetConflict: If a dataset by that name already exists.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the Beaker server.
        """
        self.validate_beaker_name(name)
        dataset_id = self.resolve_dataset(dataset).id
        return Dataset.from_json(
            self.request(
                f"datasets/{self.url_quote(dataset_id)}",
                method="PATCH",
                data=DatasetPatch(name=name),
                exceptions_for_status={
                    409: DatasetConflict(name),
                    404: DatasetNotFound(dataset_id),
                },
            ).json()
        )

    def url(self, dataset: Union[str, Dataset]) -> str:
        """
        Get the URL for a dataset.

        :param dataset: The dataset ID, name, or object.

        :raises DatasetNotFound: If the dataset can't be found.
        """
        dataset_id = self.resolve_dataset(dataset).id
        return f"{self.config.agent_address}/ds/{self.url_quote(dataset_id)}"

    def _not_found_err_msg(self, dataset: Union[str, Dataset]) -> str:
        dataset = dataset if isinstance(dataset, str) else dataset.id
        return (
            f"'{dataset}': Make sure you're using a valid Beaker dataset ID or the "
            f"*full* name of the dataset (with the account prefix, e.g. 'username/dataset_name')"
        )

    def _upload_file(
        self,
        dataset: Dataset,
        size: int,
        source: Union[PathOrStr, bytes],
        target: PathOrStr,
        progress: "Progress",
        task_id: "TaskID",
        ignore_errors: bool = False,
    ) -> int:
        from ..progress import BufferedReaderWithProgress

        assert dataset.storage is not None
        if dataset.storage.scheme != "fh":
            raise NotImplementedError(
                f"Datasets API is not implemented for '{dataset.storage.scheme}' backend yet"
            )

        source_file_wrapper: BufferedReaderWithProgress
        if isinstance(source, (str, Path, os.PathLike)):
            source = Path(source)
            if ignore_errors and not source.exists():
                return 0
            source_file_wrapper = BufferedReaderWithProgress(source.open("rb"), progress, task_id)
        elif isinstance(source, bytes):
            source_file_wrapper = BufferedReaderWithProgress(io.BytesIO(source), progress, task_id)
        else:
            raise ValueError(f"Expected path-like or raw bytes, got {type(source)}")

        try:
            body: Optional[BufferedReaderWithProgress] = source_file_wrapper
            digest: Optional[str] = None

            if size > self.REQUEST_SIZE_LIMIT:

                @retriable()
                def get_upload_id() -> str:
                    assert dataset.storage is not None  # for mypy
                    response = self.request(
                        "uploads",
                        method="POST",
                        token=dataset.storage.token,
                        base_url=dataset.storage.base_url,
                    )
                    return response.headers[self.HEADER_UPLOAD_ID]

                upload_id = get_upload_id()

                written = 0
                while written < size:
                    chunk = source_file_wrapper.read(self.REQUEST_SIZE_LIMIT)
                    if not chunk:
                        break

                    @retriable()
                    def upload() -> "Response":
                        assert dataset.storage is not None  # for mypy
                        return self.request(
                            f"uploads/{upload_id}",
                            method="PATCH",
                            data=chunk,
                            token=dataset.storage.token,
                            base_url=dataset.storage.base_url,
                            headers={
                                self.HEADER_UPLOAD_LENGTH: str(size),
                                self.HEADER_UPLOAD_OFFSET: str(written),
                            },
                        )

                    response = upload()
                    written += len(chunk)

                    digest = response.headers.get(self.HEADER_DIGEST)
                    if digest:
                        break

                if written != size:
                    raise UnexpectedEOFError(str(source))

                body = None

            @retriable()
            def finalize():
                assert dataset.storage is not None  # for mypy
                self.request(
                    f"datasets/{dataset.storage.id}/files/{str(target)}",
                    method="PUT",
                    data=body if size > 0 else b"",
                    token=dataset.storage.token,
                    base_url=dataset.storage.base_url,
                    headers=None if not digest else {self.HEADER_DIGEST: digest},
                    stream=body is not None and size > 0,
                    exceptions_for_status={
                        403: DatasetWriteError(dataset.id),
                        404: DatasetNotFound(self._not_found_err_msg(dataset.id)),
                    },
                )

            finalize()

            return source_file_wrapper.total_read
        finally:
            source_file_wrapper.close()

    def _stream_file(
        self,
        dataset: Dataset,
        file: FileInfo,
        chunk_size: Optional[int] = None,
        offset: int = 0,
        length: int = -1,
        validate_checksum: bool = True,
    ) -> Generator[bytes, None, None]:
        def stream_file() -> Generator[bytes, None, None]:
            headers = {}
            if offset > 0 and length > 0:
                headers["Range"] = f"bytes={offset}-{offset + length - 1}"
            elif offset > 0:
                headers["Range"] = f"bytes={offset}-"
            response = self.request(
                f"datasets/{dataset.id}/files/{urllib.parse.quote(file.path, safe='')}",
                method="GET",
                stream=True,
                headers=headers,
                exceptions_for_status={404: FileNotFoundError(file.path)},
            )
            for chunk in response.iter_content(chunk_size=chunk_size or self.DOWNLOAD_CHUNK_SIZE):
                yield chunk
                if is_canceled is not None and is_canceled.is_set():  # type: ignore
                    raise ThreadCanceledError

        contents_hash = None
        if offset == 0 and validate_checksum and file.digest is not None:
            contents_hash = file.digest.new_hasher()

        retries = 0
        while True:
            try:
                for chunk in stream_file():
                    if is_canceled is not None and is_canceled.is_set():  # type: ignore
                        raise ThreadCanceledError
                    offset += len(chunk)
                    if contents_hash is not None:
                        contents_hash.update(chunk)
                    yield chunk
                break
            except RequestException as err:
                if retries < self.beaker.MAX_RETRIES:
                    log_and_wait(retries, err)
                    retries += 1
                else:
                    raise

        # Validate digest.
        if file.digest is not None and contents_hash is not None:
            actual_digest = Digest.from_decoded(
                contents_hash.digest(), algorithm=file.digest.algorithm
            )
            if actual_digest != file.digest:
                raise ChecksumFailedError(
                    f"Checksum for '{file.path}' failed. "
                    f"Expected '{file.digest}', got '{actual_digest}'."
                )

    def _download_file(
        self,
        dataset: Dataset,
        file: FileInfo,
        target_path: Path,
        progress: Optional["Progress"] = None,
        task_id: Optional["TaskID"] = None,
        validate_checksum: bool = True,
        chunk_size: Optional[int] = None,
    ) -> int:
        import tempfile

        total_bytes = 0
        target_dir = target_path.parent
        target_dir.mkdir(exist_ok=True, parents=True)

        def on_failure():
            if progress is not None and task_id is not None:
                progress.advance(task_id, -total_bytes)

        @retriable(
            on_failure=on_failure,
            recoverable_errors=(RequestException, ChecksumFailedError),
        )
        def download() -> int:
            nonlocal total_bytes
            tmp_target = tempfile.NamedTemporaryFile(
                "w+b", dir=target_dir, delete=False, suffix=".tmp"
            )
            try:
                for chunk in self._stream_file(
                    dataset,
                    file,
                    validate_checksum=validate_checksum,
                    chunk_size=chunk_size,
                ):
                    total_bytes += len(chunk)
                    tmp_target.write(chunk)
                    if progress is not None and task_id is not None:
                        progress.update(task_id, advance=len(chunk))
                os.replace(tmp_target.name, target_path)
            finally:
                tmp_target.close()
                if os.path.exists(tmp_target.name):
                    os.remove(tmp_target.name)
            return total_bytes

        return download()