import re
from collections import defaultdict
from datetime import datetime
from functools import lru_cache
from typing import List

import gradio as gr
from huggingface_hub import HfApi

from get_dataset_stats import get_dataset_stats

COLLECTION_SLUG = "allenai/molmoact2-bimanualyam-dataset"
COLLECTION_URL = f"https://huggingface.co/collections/{COLLECTION_SLUG}"


@lru_cache(maxsize=1)
def get_collection_datasets() -> List[str]:
    """Return public dataset repos from the MolmoAct2-BimanualYAM collection."""
    api = HfApi()
    collection = api.get_collection(COLLECTION_SLUG, token=False)
    dataset_ids = [item.item_id for item in collection.items if item.item_type == "dataset"]
    # Preserve collection order while dropping duplicates.
    seen = set()
    unique = []
    for repo_id in dataset_ids:
        if repo_id not in seen:
            unique.append(repo_id)
            seen.add(repo_id)
    return unique


@lru_cache(maxsize=2048)
def get_cached_dataset_stats(repo_id: str):
    """Cache per-repo stats in the Space process to make repeated UI use cheaper."""
    return get_dataset_stats(repo_id, hf_token=None)


def get_allowed_dataset_set() -> set[str]:
    return set(get_collection_datasets())


def format_duration(seconds):
    """Format duration as hours, minutes, seconds."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"


def fetch_stats_for_selected(selected_datasets: List[str], progress=gr.Progress()):
    """Fetch statistics for selected collection datasets."""
    if not selected_datasets:
        return "Please select at least one dataset."

    allowed = get_allowed_dataset_set()
    selected_datasets = list(dict.fromkeys(selected_datasets))
    outside_collection = [repo_id for repo_id in selected_datasets if repo_id not in allowed]
    if outside_collection:
        blocked = "\n".join(f"- `{repo_id}`" for repo_id in outside_collection[:20])
        extra = "\n- ..." if len(outside_collection) > 20 else ""
        return (
            "Selection contains repos outside the public "
            f"[{COLLECTION_SLUG}]({COLLECTION_URL}) collection.\n\n"
            f"{blocked}{extra}"
        )

    total_episodes = 0
    v3_by_date = defaultdict(list)
    non_v3_results = []
    errors = []

    for i, repo_id in enumerate(selected_datasets):
        try:
            progress((i + 1) / len(selected_datasets), desc=f"Processing {repo_id}...")
            stats = get_cached_dataset_stats(repo_id)
            if stats.get("error"):
                errors.append(f"Error for `{repo_id}`: {stats['error']}")
                continue

            episodes = stats["total_episodes"]
            total_episodes += episodes
            is_v3 = stats.get("format_version") == "v3.0"

            if is_v3:
                # Repo ids embed an 8-digit date segment (DDMMYYYY) used for grouping.
                date_match = re.search(r"/(\d{8})", repo_id)
                if date_match:
                    date_str = date_match.group(1)
                    try:
                        date_obj = datetime.strptime(date_str, "%d%m%Y")
                        date_key = date_obj.strftime("%Y-%m-%d")
                        date_display = date_obj.strftime("%B %d, %Y")
                    except ValueError:
                        date_key = date_str
                        date_display = date_str
                else:
                    date_key = "unknown"
                    date_display = "Unknown Date"
                v3_by_date[date_key].append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "date_display": date_display,
                        "stats": stats,
                    }
                )
            else:
                non_v3_results.append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "stats": stats,
                    }
                )
        except Exception as e:
            errors.append(f"Error for `{repo_id}`: {e}")

    # Total duration across all selected datasets; fall back to 30 fps when metadata omits it.
    total_duration_seconds = 0
    for datasets in v3_by_date.values():
        for dataset in datasets:
            info_meta = dataset["stats"].get("info_metadata", {}) or {}
            if info_meta.get("total_frames"):
                fps = info_meta.get("fps", 30)
                total_duration_seconds += info_meta["total_frames"] / fps
    for dataset in non_v3_results:
        info_meta = dataset["stats"].get("info_metadata", {}) or {}
        if info_meta.get("total_frames"):
            fps = info_meta.get("fps", 30)
            total_duration_seconds += info_meta["total_frames"] / fps

    duration_display = f" • {format_duration(total_duration_seconds)}" if total_duration_seconds > 0 else ""
    output = [
        f"## Total Episodes: {total_episodes}{duration_display}",
        f"Selected datasets: **{len(selected_datasets)}** / **{len(allowed)}**",
    ]

    if v3_by_date:
        sorted_dates = sorted([key for key in v3_by_date if key != "unknown"], reverse=True)
        if "unknown" in v3_by_date:
            sorted_dates.append("unknown")
        for date_key in sorted_dates:
            datasets = v3_by_date[date_key]
            date_display = datasets[0]["date_display"]
            date_total_episodes = sum(dataset["episodes"] for dataset in datasets)
            date_total_seconds = 0
            for dataset in datasets:
                info_meta = dataset["stats"].get("info_metadata", {}) or {}
                if info_meta.get("total_frames"):
                    fps = info_meta.get("fps", 30)
                    date_total_seconds += info_meta["total_frames"] / fps
            output.append(
                f"\n**{date_display}** — Total: **{date_total_episodes} episodes**"
                f" • {format_duration(date_total_seconds)}"
            )
            for dataset in sorted(datasets, key=lambda item: item["repo_id"]):
                repo_name = dataset["repo_id"].split("/")[-1]
                episodes = dataset["episodes"]
                info_meta = dataset["stats"].get("info_metadata", {}) or {}
                duration_str = ""
                if info_meta.get("total_frames"):
                    fps = info_meta.get("fps", 30)
                    duration_str = f" • {format_duration(info_meta['total_frames'] / fps)}"
                output.append(f"- `{repo_name}`: **{episodes} episodes**{duration_str}")

    if non_v3_results:
        output.append("\n### Other Formats")
        for dataset in non_v3_results:
            info_meta = dataset["stats"].get("info_metadata", {}) or {}
            duration_str = ""
            if info_meta.get("total_frames"):
                fps = info_meta.get("fps", 30)
                duration_str = f" • {format_duration(info_meta['total_frames'] / fps)}"
            output.append(f"- `{dataset['repo_id']}`: **{dataset['episodes']} episodes**{duration_str}")

    if errors:
        output.append("\n### Errors")
        output.extend(f"- {error}" for error in errors)

    return "\n".join(output)


def load_collection_datasets():
    get_collection_datasets.cache_clear()
    datasets = get_collection_datasets()
    return [
        gr.update(choices=datasets, value=[]),
        datasets,
        f"Loaded **{len(datasets)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL}).",
    ]


def select_matching(filter_text: str, choices: List[str]):
    choices = choices or []
    query = (filter_text or "").strip().lower()
    if not query:
        return gr.update(value=choices)
    return gr.update(value=[repo_id for repo_id in choices if query in repo_id.lower()])


_initial_datasets = get_collection_datasets()

with gr.Blocks(title="MolmoAct2-BimanualYAM Dataset Stats") as demo:
    gr.Markdown(
        "# MolmoAct2-BimanualYAM Dataset Stats\n"
        f"Public stats viewer for datasets in [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )

    current_choices = gr.State(_initial_datasets)
    collection_status = gr.Markdown(
        f"Loaded **{len(_initial_datasets)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )

    with gr.Row():
        refresh_btn = gr.Button("Refresh Collection", variant="secondary")
        filter_box = gr.Textbox(label="Filter", placeholder="Example: tablebuss, scan, 02012026")

    dataset_checkboxes = gr.CheckboxGroup(
        label="Select Datasets",
        choices=_initial_datasets,
        interactive=True,
    )

    with gr.Row():
        select_all_btn = gr.Button("Select All", size="sm")
        select_matching_btn = gr.Button("Select Matching", size="sm")
        deselect_all_btn = gr.Button("Deselect All", size="sm")

    fetch_btn = gr.Button("Fetch Statistics", variant="primary")
    stats_output = gr.Markdown(
        label="Dataset Statistics",
        value="Select datasets and click **Fetch Statistics**.",
    )

    refresh_btn.click(
        load_collection_datasets,
        outputs=[dataset_checkboxes, current_choices, collection_status],
    )
    select_all_btn.click(
        lambda choices: gr.update(value=choices),
        inputs=current_choices,
        outputs=dataset_checkboxes,
    )
    select_matching_btn.click(
        select_matching,
        inputs=[filter_box, current_choices],
        outputs=dataset_checkboxes,
    )
    deselect_all_btn.click(
        lambda: gr.update(value=[]),
        outputs=dataset_checkboxes,
    )
    fetch_btn.click(
        fetch_stats_for_selected,
        inputs=dataset_checkboxes,
        outputs=stats_output,
    )


if __name__ == "__main__":
    demo.launch()