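"""Source code for ``oml.utils.download_mock_dataset``: helpers returning train/validation DataFrames of the mock images, audios and texts datasets."""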

from pathlib import Path
from typing import Tuple, Union

import gdown
import pandas as pd

from oml.const import (
    CATEGORIES_COLUMN,
    IS_GALLERY_COLUMN,
    IS_QUERY_COLUMN,
    LABELS_COLUMN,
    MOCK_AUDIO_DATASET_MD5,
    MOCK_AUDIO_DATASET_PATH,
    MOCK_AUDIO_DATASET_URL_GDRIVE,
    MOCK_DATASET_DEFAULT_CSV,
    MOCK_DATASET_MD5,
    MOCK_DATASET_PATH,
    MOCK_DATASET_URL_GDRIVE,
    SPLIT_COLUMN,
    TEXTS_COLUMN,
)
from oml.utils.io import check_exists_and_validate_md5
from oml.utils.remote_storage import download_folder_from_remote_storage


def _get_mock_dataset(
    dataset_local_folder: Union[str, Path],
    dataset_remote_folder: str,
    dataset_md5: str,
    dataset_gdrive_url: str,
    df_name: str,
    check_md5: bool = True,
    global_paths: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Downloads (if needed) and prepares a mock dataset in the required format.

    The dataset is first requested from ``oml.daloroserver.com``; if that fails for any reason,
    Google Drive is used as a fallback. The local copy is validated after downloading.

    Args:
        dataset_local_folder: The directory where the dataset will be saved.
        dataset_remote_folder: The remote directory on `oml.daloroserver.com` from which the dataset will be downloaded.
        dataset_md5: The MD5 checksum used to validate the dataset.
        dataset_gdrive_url: The Google Drive URL for the dataset download as a fallback option.
        df_name: The name of the CSV file from which the output DataFrames will be generated.
        check_md5: If ``True``, validates the dataset using an MD5 checksum.
        global_paths: If ``True``, concatenates the paths in the dataset with the dataset_local_folder.

    Returns:
        A tuple containing two DataFrames:
            - The first DataFrame is for the training stage (rows whose split is ``"train"``).
            - The second DataFrame is for the validation stage (rows whose split is ``"validation"``),
              with the query/gallery columns cast to ``bool``.

    Raises:
        Exception: If the downloaded dataset is invalid.
    """
    dataset_local_folder = Path(dataset_local_folder)
    # ``None`` is passed when checksum validation is disabled; presumably the helper then only
    # checks that the folder exists — TODO confirm against ``check_exists_and_validate_md5``.
    expected_md5 = dataset_md5 if check_md5 else None

    if check_exists_and_validate_md5(dataset_local_folder, expected_md5):
        print(f"Mock dataset has been downloaded already to {dataset_local_folder}")
    else:
        try:
            print("Downloading from oml.daloroserver.com")
            download_folder_from_remote_storage(
                remote_folder=dataset_remote_folder, local_folder=str(dataset_local_folder)
            )
        except Exception as err:
            # The primary source is best-effort: fall back to Google Drive, but surface why it failed
            # instead of silently discarding the error.
            print(f"We could not download from oml.daloroserver.com ({err!r}), let's try Google Drive.")
            gdown.download_folder(url=dataset_gdrive_url, output=str(dataset_local_folder))

        # Re-validate only after an actual download; an already-valid folder needs no second check.
        if not check_exists_and_validate_md5(dataset_local_folder, expected_md5):
            raise Exception("Downloaded mock dataset is invalid.")

    df = pd.read_csv(dataset_local_folder / df_name)

    if global_paths:
        # Paths stored in the CSV are relative to the dataset folder; make them absolute-ish.
        df["path"] = df["path"].apply(lambda x: str(dataset_local_folder / x))

    df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True)
    df_val = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True)

    df_val = df_val.astype({IS_QUERY_COLUMN: bool, IS_GALLERY_COLUMN: bool})

    return df_train, df_val


def get_mock_images_dataset(
    dataset_root: Union[str, Path] = MOCK_DATASET_PATH,
    df_name: str = MOCK_DATASET_DEFAULT_CSV,
    check_md5: bool = True,
    global_paths: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Download (if needed) the mock images dataset, which is already prepared in the required format,
    and split it into training and validation parts.

    Args:
        dataset_root: The directory where the dataset will be saved.
        df_name: The name of the CSV file from which the output DataFrames will be generated.
        check_md5: If ``True``, validates the dataset using an MD5 checksum.
        global_paths: If ``True``, concatenates the paths in the dataset with the dataset_local_folder.

    Returns:
        A tuple containing two DataFrames:
            - The first DataFrame is for the training stage.
            - The second DataFrame is for the validation stage.
    """
    # The remote folder name always comes from ``MOCK_DATASET_PATH``, even if ``dataset_root`` differs.
    remote_folder_name = MOCK_DATASET_PATH.name
    return _get_mock_dataset(
        dataset_root,
        remote_folder_name,
        MOCK_DATASET_MD5,
        MOCK_DATASET_URL_GDRIVE,
        df_name,
        check_md5,
        global_paths,
    )


def download_mock_dataset(
    dataset_root: Union[str, Path] = MOCK_DATASET_PATH,
    check_md5: bool = True,
    df_name: str = "df.csv",
    global_paths: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Backward-compatible alias of ``get_mock_images_dataset``.

    Note that ``check_md5`` precedes ``df_name`` here (the order differs from
    ``get_mock_images_dataset``) and that ``df_name`` defaults to ``"df.csv"``.
    """
    # for back compatibility
    return get_mock_images_dataset(
        dataset_root=dataset_root,
        df_name=df_name,
        check_md5=check_md5,
        global_paths=global_paths,
    )


def get_mock_audios_dataset(
    dataset_root: Union[str, Path] = MOCK_AUDIO_DATASET_PATH,
    df_name: str = MOCK_DATASET_DEFAULT_CSV,
    check_md5: bool = True,
    global_paths: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Download (if needed) the mock audios dataset, which is already prepared in the required format,
    and split it into training and validation parts.

    Args:
        dataset_root: The directory where the dataset will be saved.
        df_name: The name of the CSV file from which the output DataFrames will be generated.
        check_md5: If ``True``, validates the dataset using an MD5 checksum.
        global_paths: If ``True``, concatenates the paths in the dataset with the dataset_local_folder.

    Returns:
        A tuple containing two DataFrames:
            - The first DataFrame is for the training stage.
            - The second DataFrame is for the validation stage.
    """
    # The remote folder name always comes from ``MOCK_AUDIO_DATASET_PATH``, even if ``dataset_root`` differs.
    remote_folder_name = MOCK_AUDIO_DATASET_PATH.name
    return _get_mock_dataset(
        dataset_root,
        remote_folder_name,
        MOCK_AUDIO_DATASET_MD5,
        MOCK_AUDIO_DATASET_URL_GDRIVE,
        df_name,
        check_md5,
        global_paths,
    )


def get_mock_texts_dataset() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Mock texts dataset useful for prototyping pipelines and understanding dataset structure.

    The dataset contains short product descriptions grouped into 7 groups, each with its own label.
    Groups 0-3 (beds, tables, TVs, potatoes) form the training split and groups 4-6 (chairs, phones,
    audio devices) form the validation split, so validation labels do not appear in training.

    Returns:
        A tuple containing two DataFrames:
            - The first DataFrame is for the training stage.
            - The second DataFrame is for the validation stage. Every validation row is both a query
              and a gallery item, and these two columns have ``bool`` dtype.
    """
    bed = [
        "Luxura King Bed: Plush headboard and durable frame for ultimate comfort.",
        "EcoSleep Twin Bed: Sustainable materials for eco-conscious consumers.",
        "DreamRest Queen Bed: Ergonomic design for restful sleep.",
        "ClassicWood Full Bed: Sturdy wood construction with a sleek finish.",
    ]
    table = [
        "UrbanChic Dining Table: Glass top with metal legs seats six.",
        "RusticFarmhouse Coffee Table: Reclaimed wood for rustic charm.",
        "Minimalist Work Desk: Sleek, modern workspace.",
        "ArtisanCraft End Table: Handcrafted accent piece.",
        "Vintage Bistro Table: Classic design for cozy corners.",
        "Compact Folding Table: Versatile and easy to store.",
    ]
    tv = [
        'UltraHD Smart TV 55": Stunning visuals and smart features.',
        'Compact LED TV 32": Crisp picture in a space-saving design.',
        'Curved 4K TV 65": Panoramic viewing with brilliant colors.',
        'BudgetFriendly LCD TV 40": Clear picture and essential features.',
    ]
    potato = [
        "Yukon Gold Potato: Creamy and golden, perfect for roasting and mashing.",
        "Russet Potato: Classic and versatile, ideal for baking and frying.",
        "Red Potato: Smooth and firm, great for boiling and salads.",
        "Fingerling Potato: Rich and nutty, perfect for roasting or grilling.",
        "Purple Potato: Vibrant and sweet, adds color to any dish.",
    ]
    chair = [
        "ErgoComfort Office Chair: Maximum comfort and support.",
        "ClassicWood Dining Chair: Timeless design with sturdy construction.",
        "RelaxLounge Recliner: Adjustable settings and plush cushioning.",
        "ModernAccent Chair: Contemporary style with vibrant color options.",
    ]
    phone = [
        "Galaxy X9: Sleek design with powerful performance.",
        "iPhone 12: Cutting-edge technology with a stylish look.",
        "Pixel 5: Pure Android experience with excellent camera.",
        "OnePlus 8T: High performance at a competitive price.",
        "Moto G Power: Long battery life for extended use.",
        "Sony Xperia 5: Superior camera and display quality.",
    ]
    audio = [
        "NoiseCancelling Headphones Pro: Superior sound and noise cancellation.",
        "Wireless Earbuds Sport: Great sound with a secure fit.",
        "StudioOverEar Headphones: Exceptional audio clarity.",
        "BudgetFriendly Wired Headphones: Good sound at an affordable price.",
        "Gaming Headset Pro: Surround sound for gaming.",
        "TravelNoiseCancelling Earbuds: Compact with excellent noise cancellation.",
    ]

    # One entry per group: (texts, category, split). A group's label is its index in this list,
    # so per-row columns can no longer drift out of sync with the texts.
    groups = [
        (bed, "furniture", "train"),
        (table, "furniture", "train"),
        (tv, "electronic", "train"),
        (potato, "food", "train"),
        (chair, "furniture", "validation"),
        (phone, "electronic", "validation"),
        (audio, "electronic", "validation"),
    ]

    texts = [text for group_texts, _, _ in groups for text in group_texts]
    labels = [label for label, (group_texts, _, _) in enumerate(groups) for _ in group_texts]
    categories = [category for group_texts, category, _ in groups for _ in group_texts]
    split = [group_split for group_texts, _, group_split in groups for _ in group_texts]

    # Training rows have no query/gallery role; every validation row is both a query and a gallery item.
    is_query = [None if row_split == "train" else True for row_split in split]
    is_gallery = list(is_query)

    data = {
        TEXTS_COLUMN: texts,
        LABELS_COLUMN: labels,
        CATEGORIES_COLUMN: categories,
        SPLIT_COLUMN: split,
        IS_QUERY_COLUMN: is_query,
        IS_GALLERY_COLUMN: is_gallery,
    }
    df = pd.DataFrame(data)

    df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True)
    df_val = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True)

    # Mixing None and True yields object-dtype columns; cast the validation part to bool so it
    # matches the dtype produced for the downloaded mock datasets.
    df_val = df_val.astype({IS_QUERY_COLUMN: bool, IS_GALLERY_COLUMN: bool})

    return df_train, df_val


# Public API of this module; the legacy ``download_mock_dataset`` alias is not listed here.
__all__ = [
    "get_mock_images_dataset",
    "get_mock_texts_dataset",
    "get_mock_audios_dataset",
]