Skip to content

API Reference

Auto-generated documentation from Python docstrings using mkdocstrings.

Note

Some modules may have limited docstring coverage. This reference will improve over time.


Project Management

castle.utils.project_manager

Project management utilities for Castle AI.

create_project(storage_path, project_name)

Create a new project with config file.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the new project

required

Raises:

Type Description
FileExistsError

If project already exists

Source code in castle/utils/project_manager.py
def create_project(storage_path, project_name):
    """Create a new project directory containing an initial config file.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the new project

    Raises:
        FileExistsError: If project already exists
    """
    project_path = os.path.join(storage_path, project_name)

    # Refuse to clobber an existing project.
    if os.path.exists(project_path):
        raise FileExistsError(f"Project '{project_name}' already exists")

    os.makedirs(project_path)

    # Seed the new project with its initial configuration file.
    initial_config = {
        'project_name': project_name,
        'created_at': datetime.datetime.now().isoformat(),
    }
    with open(os.path.join(project_path, 'config.json'), 'w') as f:
        json.dump(initial_config, f, indent=2)

delete_project(storage_path, project_name)

Delete a project directory.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the project to delete

required

Returns:

Name Type Description
bool

True if deletion was successful, False otherwise

Source code in castle/utils/project_manager.py
def delete_project(storage_path, project_name):
    """Delete a project directory.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the project to delete

    Returns:
        bool: True if deletion was successful, False otherwise
    """
    project_path = os.path.join(storage_path, project_name)

    # Guard clause: nothing to remove when the directory is absent.
    if not os.path.exists(project_path):
        logger.warning(f"Project not found: {project_path}")
        return False

    shutil.rmtree(project_path)
    logger.info(f"Deleted project: {project_path}")
    return True

generate_default_project_name(custom_name='')

Generate a default project name with timestamp.

Parameters:

Name Type Description Default
custom_name

Optional custom name to use

''

Returns:

Name Type Description
str

Project name

Source code in castle/utils/project_manager.py
def generate_default_project_name(custom_name=''):
    """Generate a default project name with timestamp.

    Args:
        custom_name: Optional custom name to use

    Returns:
        str: Project name — the stripped custom name when one was given,
        otherwise a timestamp-based name.
    """
    trimmed = custom_name.strip() if custom_name else ''
    if trimmed:
        return trimmed

    # Fall back to a sortable timestamp-based name.
    now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d-%H-%M-%S') + '-Project'

initialize_storage(root_path)

Initialize storage directory if it doesn't exist.

Parameters:

Name Type Description Default
root_path

Path to the root storage directory

required

Returns:

Name Type Description
str

Normalized storage path with trailing separator

Source code in castle/utils/project_manager.py
def initialize_storage(root_path):
    """Initialize storage directory if it doesn't exist.

    Args:
        root_path: Path to the root storage directory; falls back to a
            local ``projects`` directory when None or empty.

    Returns:
        str: Normalized storage path with trailing separator
    """
    if root_path is None or root_path == '':
        # Default to a local "projects" directory.
        root_path = os.path.join('projects', '')

    # Ensure the returned path always ends with the platform separator so
    # callers can concatenate file names directly.
    if not root_path.endswith(os.sep):
        root_path += os.sep

    # exist_ok=True avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() sequence.
    os.makedirs(root_path, exist_ok=True)

    return root_path

list_projects(storage_path)

List all project directories in the storage path.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required

Returns:

Type Description

List of project directory names, sorted if more than one exists

Source code in castle/utils/project_manager.py
def list_projects(storage_path):
    """List all project directories in the storage path.

    Args:
        storage_path: Path to the storage directory

    Returns:
        List of project directory names, sorted if more than one exists
    """
    if not os.path.exists(storage_path):
        return []

    projects = []
    for entry in os.listdir(storage_path):
        # Only directories count as projects; plain files are ignored.
        if os.path.isdir(os.path.join(storage_path, entry)):
            projects.append(entry)

    # With multiple entries, present them in reverse-sorted order.
    if len(projects) > 1:
        projects.sort(reverse=True)
    return projects

Video I/O

castle.utils.video_io

影片輸入/輸出工具模組

該模組提供了影片讀取、寫入和字幕生成的功能。 相較於使用 OpenCV,新版本使用 av 庫提供更高效的影片處理能力。

主要功能: - 高效的影片讀取和隨機存取 - 影片編碼和寫入 - 字幕檔案生成 (SRT, WebVTT) - 自動資源管理和錯誤處理

SubtitleGenerator

字幕生成器

提供字幕檔案的建立和儲存功能,支援多種字幕格式。 可用於為影片自動生成時間同步的字幕檔案。

支援格式: - SRT (SubRip Text) - WebVTT (Web Video Text Tracks)

範例用法: generator = SubtitleGenerator() generator.add_subtitle(0.0, 5.0, "第一段字幕") generator.add_subtitle(5.0, 10.0, "第二段字幕") generator.save("output.srt", format="srt")

Source code in castle/utils/video_io.py
class SubtitleGenerator:
    """
    Subtitle generator.

    Provides creation and saving of subtitle files, supporting multiple
    subtitle formats. Can be used to generate time-synchronized subtitle
    files for videos.

    Supported formats:
    - SRT (SubRip Text)
    - WebVTT (Web Video Text Tracks)

    Example usage:
        generator = SubtitleGenerator()
        generator.add_subtitle(0.0, 5.0, "first subtitle")
        generator.add_subtitle(5.0, 10.0, "second subtitle")
        generator.save("output.srt", format="srt")
    """

    def __init__(self):
        """Initialize the subtitle generator with an empty entry list."""
        self.subtitles = []
        logger.debug("字幕生成器初始化完成")

    def add_subtitle(
        self,
        start_time: float,
        end_time: float,
        text: str
    ) -> None:
        """
        Add one subtitle entry.

        Args:
            start_time: Start time in seconds.
            end_time: End time in seconds.
            text: Subtitle text content.

        Raises:
            ValueError: If the time parameters or text are invalid.
        """
        if start_time < 0 or end_time < 0:
            raise ValueError("時間不能為負數")

        if start_time >= end_time:
            raise ValueError("開始時間必須小於結束時間")

        if not text.strip():
            raise ValueError("字幕文字不能為空")

        self.subtitles.append({
            'start': start_time,
            'end': end_time,
            'text': text.strip()
        })

        logger.debug(f"添加字幕: {start_time:.1f}s-{end_time:.1f}s: {text[:30]}...")

    def save(self, output_path: Union[str, Path], format: str = 'srt') -> None:
        """
        Save the collected subtitles to a file.

        Args:
            output_path: Output file path.
            format: Subtitle format ('srt', 'vtt').

        Raises:
            ValueError: If the format is not supported.
            IOError: If writing the file fails.
        """
        if not self.subtitles:
            logger.warning("沒有字幕資料可儲存")
            return

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            if format.lower() == 'srt':
                self._save_srt(output_path)
            elif format.lower() == 'vtt':
                self._save_vtt(output_path)
            else:
                raise ValueError(f"不支援的字幕格式: {format}")

            logger.info(f"字幕檔案儲存完成: {output_path} ({len(self.subtitles)} 個項目)")

        except Exception as e:
            raise IOError(f"儲存字幕檔案失敗: {e}")

    def _save_srt(self, output_path: Path) -> None:
        """Write the subtitles in SRT format."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, sub in enumerate(self.subtitles, 1):
                f.write(f"{i}\n")
                f.write(f"{self._format_srt_time(sub['start'])} --> ")
                f.write(f"{self._format_srt_time(sub['end'])}\n")
                f.write(f"{sub['text']}\n\n")

    def _save_vtt(self, output_path: Path) -> None:
        """Write the subtitles in WebVTT format."""
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("WEBVTT\n\n")

            for sub in self.subtitles:
                f.write(f"{self._format_vtt_time(sub['start'])} --> ")
                f.write(f"{self._format_vtt_time(sub['end'])}\n")
                f.write(f"{sub['text']}\n\n")

    @staticmethod
    def _format_srt_time(seconds: float) -> str:
        """
        Format a timestamp for SRT.

        Args:
            seconds: Time in seconds.

        Returns:
            SRT-style time string (HH:MM:SS,mmm).
        """
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60

        # SRT uses a comma as the decimal separator.
        return f"{hours:02d}:{minutes:02d}:{secs:06.3f}".replace('.', ',')

    @staticmethod
    def _format_vtt_time(seconds: float) -> str:
        """
        Format a timestamp for WebVTT.

        Args:
            seconds: Time in seconds.

        Returns:
            WebVTT-style time string (HH:MM:SS.mmm).
        """
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60

        return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"

    def clear(self) -> None:
        """Remove all subtitle entries."""
        self.subtitles.clear()
        logger.debug("字幕項目已清除")

    def get_subtitle_count(self) -> int:
        """Return the number of subtitle entries."""
        return len(self.subtitles)

    def get_total_duration(self) -> float:
        """
        Get the total subtitle duration.

        Returns:
            Total duration in seconds; 0 when there are no subtitles.
        """
        if not self.subtitles:
            return 0.0

        return max(sub['end'] for sub in self.subtitles)

__init__()

初始化字幕生成器

Source code in castle/utils/video_io.py
def __init__(self):
    """Initialize the subtitle generator with an empty entry list."""
    self.subtitles = []
    logger.debug("字幕生成器初始化完成")

add_subtitle(start_time, end_time, text)

添加字幕項目

Parameters:

Name Type Description Default
start_time float

開始時間(秒)

required
end_time float

結束時間(秒)

required
text str

字幕文字內容

required

Raises:

Type Description
ValueError

當時間參數無效時

Source code in castle/utils/video_io.py
def add_subtitle(
    self,
    start_time: float,
    end_time: float,
    text: str
) -> None:
    """
    Validate and append a single subtitle entry.

    Args:
        start_time: Start time in seconds (must be >= 0).
        end_time: End time in seconds (must be greater than start_time).
        text: Subtitle text; surrounding whitespace is stripped.

    Raises:
        ValueError: If the time range or the text is invalid.
    """
    # Reject invalid inputs before touching any state.
    if start_time < 0 or end_time < 0:
        raise ValueError("時間不能為負數")

    if start_time >= end_time:
        raise ValueError("開始時間必須小於結束時間")

    cleaned = text.strip()
    if not cleaned:
        raise ValueError("字幕文字不能為空")

    entry = {
        'start': start_time,
        'end': end_time,
        'text': cleaned
    }
    self.subtitles.append(entry)

    logger.debug(f"添加字幕: {start_time:.1f}s-{end_time:.1f}s: {text[:30]}...")

clear()

清除所有字幕項目

Source code in castle/utils/video_io.py
def clear(self) -> None:
    """Remove all subtitle entries."""
    self.subtitles.clear()
    logger.debug("字幕項目已清除")

get_subtitle_count()

獲取字幕項目數量

Source code in castle/utils/video_io.py
def get_subtitle_count(self) -> int:
    """Return how many subtitle entries are currently stored."""
    count = len(self.subtitles)
    return count

get_total_duration()

獲取字幕總時長

Returns:

Type Description
float

總時長(秒),如果沒有字幕則返回 0

Source code in castle/utils/video_io.py
def get_total_duration(self) -> float:
    """
    Return the total subtitle duration.

    Returns:
        The largest 'end' timestamp in seconds, or 0.0 when no
        subtitles have been added.
    """
    ends = [sub['end'] for sub in self.subtitles]
    return max(ends) if ends else 0.0

save(output_path, format='srt')

儲存字幕檔案

Parameters:

Name Type Description Default
output_path Union[str, Path]

輸出檔案路徑

required
format str

字幕格式 ('srt', 'vtt')

'srt'

Raises:

Type Description
ValueError

當格式不支援時

IOError

當檔案寫入失敗時

Source code in castle/utils/video_io.py
def save(self, output_path: Union[str, Path], format: str = 'srt') -> None:
    """
    Save the collected subtitles to a file.

    Args:
        output_path: Output file path.
        format: Subtitle format ('srt', 'vtt').

    Raises:
        ValueError: If the format is not supported.
        IOError: If writing the file fails.
    """
    if not self.subtitles:
        logger.warning("沒有字幕資料可儲存")
        return

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        if format.lower() == 'srt':
            self._save_srt(output_path)
        elif format.lower() == 'vtt':
            self._save_vtt(output_path)
        else:
            raise ValueError(f"不支援的字幕格式: {format}")

        logger.info(f"字幕檔案儲存完成: {output_path} ({len(self.subtitles)} 個項目)")

    except Exception as e:
        raise IOError(f"儲存字幕檔案失敗: {e}")

VideoIO

影片輸入/輸出處理器

提供靜態方法來處理影片的讀取、寫入和基本操作。 使用 av 庫實現高效的影片處理功能。

Source code in castle/utils/video_io.py
class VideoIO:
    """
    Video input/output handler.

    Provides static methods for reading, writing, and basic handling of
    videos, implemented on top of the av library.
    """

    @staticmethod
    def load_video(video_path: Union[str, Path]) -> 'VideoReader':
        """
        Load a video file.

        Args:
            video_path: Path to the video file.

        Returns:
            A VideoReader instance.

        Raises:
            VideoIOError: If loading the video fails.
        """
        try:
            return VideoReader(video_path)
        except Exception as e:
            raise VideoIOError(f"載入影片失敗: {e}")

    @staticmethod
    def save_video(
        frames: List[np.ndarray],
        output_path: Union[str, Path],
        fps: float = 30.0,
        crf: int = 15,
        codec: str = 'libx264'
    ) -> None:
        """
        Save frames as a video file.

        Args:
            frames: List of frame arrays.
            output_path: Output file path.
            fps: Frame rate.
            crf: Compression quality (0-51; lower means better quality).
            codec: Encoder name.

        Raises:
            VideoIOError: If saving the video fails.
            ValueError: If the input data is invalid.
        """
        if not frames:
            raise ValueError("沒有影格資料可儲存")

        if not all(isinstance(frame, np.ndarray) for frame in frames):
            raise ValueError("所有影格必須是 numpy 陣列")

        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # Write frames via the VideoWriter helper class.
            with VideoWriter(str(output_path), fps, crf, codec) as writer:
                for frame in frames:
                    writer.write_frame(frame)

            logger.info(f"成功儲存影片到 {output_path}")

        except Exception as e:
            raise VideoIOError(f"儲存影片失敗: {e}")

    @staticmethod
    def get_frame(video: 'VideoReader', frame_idx: int) -> np.ndarray:
        """
        Fetch the frame at the given index from a VideoReader.

        Args:
            video: VideoReader instance.
            frame_idx: Frame index.

        Returns:
            Frame array (RGB).
        """
        return video.get_frame(frame_idx)

get_frame(video, frame_idx) staticmethod

從 VideoReader 獲取指定影格

Parameters:

Name Type Description Default
video VideoReader

VideoReader 實例

required
frame_idx int

影格索引

required

Returns:

Type Description
ndarray

影格陣列 (RGB)

Source code in castle/utils/video_io.py
@staticmethod
def get_frame(video: 'VideoReader', frame_idx: int) -> np.ndarray:
    """
    Fetch the frame at the given index from a VideoReader.

    Args:
        video: VideoReader instance.
        frame_idx: Frame index.

    Returns:
        Frame array (RGB).
    """
    return video.get_frame(frame_idx)

load_video(video_path) staticmethod

載入影片檔案

Parameters:

Name Type Description Default
video_path Union[str, Path]

影片檔案路徑

required

Returns:

Type Description
VideoReader

VideoReader 物件

Raises:

Type Description
VideoIOError

當影片載入失敗時

Source code in castle/utils/video_io.py
@staticmethod
def load_video(video_path: Union[str, Path]) -> 'VideoReader':
    """
    Load a video file.

    Args:
        video_path: Path to the video file.

    Returns:
        A VideoReader instance.

    Raises:
        VideoIOError: If loading the video fails.
    """
    try:
        return VideoReader(video_path)
    except Exception as e:
        raise VideoIOError(f"載入影片失敗: {e}")

save_video(frames, output_path, fps=30.0, crf=15, codec='libx264') staticmethod

儲存影片

Parameters:

Name Type Description Default
frames List[ndarray]

影格陣列列表

required
output_path Union[str, Path]

輸出檔案路徑

required
fps float

幀率

30.0
crf int

壓縮品質 (0-51, 數值越小品質越好)

15
codec str

編碼器名稱

'libx264'

Raises:

Type Description
VideoIOError

當影片儲存失敗時

ValueError

當輸入資料無效時

Source code in castle/utils/video_io.py
@staticmethod
def save_video(
    frames: List[np.ndarray],
    output_path: Union[str, Path],
    fps: float = 30.0,
    crf: int = 15,
    codec: str = 'libx264'
) -> None:
    """
    Save frames as a video file.

    Args:
        frames: List of frame arrays.
        output_path: Output file path.
        fps: Frame rate.
        crf: Compression quality (0-51; lower means better quality).
        codec: Encoder name.

    Raises:
        VideoIOError: If saving the video fails.
        ValueError: If the input data is invalid.
    """
    if not frames:
        raise ValueError("沒有影格資料可儲存")

    if not all(isinstance(frame, np.ndarray) for frame in frames):
        raise ValueError("所有影格必須是 numpy 陣列")

    try:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write frames via the VideoWriter helper class.
        with VideoWriter(str(output_path), fps, crf, codec) as writer:
            for frame in frames:
                writer.write_frame(frame)

        logger.info(f"成功儲存影片到 {output_path}")

    except Exception as e:
        raise VideoIOError(f"儲存影片失敗: {e}")

VideoIOError

Bases: Exception

影片 I/O 特定的錯誤類別

Source code in castle/utils/video_io.py
class VideoIOError(Exception):
    """Error type raised for video I/O specific failures."""

VideoInfo dataclass

影片資訊資料類別

用於儲存影片的基本屬性資訊,包括路徑、幀率、影格數量、尺寸和時長等。

Source code in castle/utils/video_io.py
@dataclass
class VideoInfo:
    """
    Basic video property container.

    Bundles the source path with the video's frame rate, total frame
    count, pixel dimensions, and duration.
    """
    path: Path          # video file path
    fps: float          # frame rate (frames per second)
    frame_count: int    # total number of frames
    width: int          # frame width in pixels
    height: int         # frame height in pixels
    duration: float     # total duration in seconds

VideoReader

影片讀取器

使用 av 庫實現高效的影片讀取功能,支援隨機存取和批次讀取。 相較於傳統的 OpenCV 方法,提供更好的效能和準確性。

主要功能: - 高效的隨機影格存取 - 自動資源管理 - 影格快取機制 - 批次讀取支援

範例用法: # 基本使用 reader = VideoReader("video.mp4") frame = reader.get_frame(100) # 獲取第 100 幀

# 作為 context manager 使用
with VideoReader("video.mp4") as reader:
    for i, frame in reader.iterate_frames(0, 100, 5):
        process_frame(frame)
Source code in castle/utils/video_io.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
class VideoReader:
    """
    Video reader.

    Implements efficient video reading on top of the av library, with
    support for random access and batch reads. Compared with the
    traditional OpenCV approach it offers better performance and accuracy.

    Main features:
    - Efficient random frame access
    - Automatic resource management
    - Frame caching
    - Batch read support

    Example usage:
        # Basic use
        reader = VideoReader("video.mp4")
        frame = reader.get_frame(100)  # fetch frame 100

        # As a context manager
        with VideoReader("video.mp4") as reader:
            for i, frame in reader.iterate_frames(0, 100, 5):
                process_frame(frame)
    """

    def __init__(self, video_path: Union[str, Path]):
        """
        Initialize the video reader.

        Args:
            video_path: Path to the video file.

        Raises:
            VideoIOError: If initialization fails.
            FileNotFoundError: If the video file does not exist.
        """
        self.path = Path(video_path)
        if not self.path.exists():
            raise FileNotFoundError(f"影片檔案不存在: {video_path}")


        try:
            # Open the video container.
            self.container = av.open(str(self.path))
            self.video_stream = self.container.streams.video[0]

            # Parameters for converting a frame's pts into a frame index.
            self.fps = float(self.video_stream.average_rate)
            time_base = self.video_stream.time_base
            self.pts2index = time_base * self.video_stream.average_rate

            # Video dimensions.
            self.width = self.video_stream.width
            self.height = self.video_stream.height

            # Total frame count (needs special handling for accuracy).
            self.frame_count = self._calculate_frame_count()
            self.duration = self.frame_count / self.fps if self.fps > 0 else 0

            # Internal state.
            self._current_index = -1
            self._frame_cache = {}
            self._closed = False

            logger.debug(f"影片讀取器初始化完成: {self.path}")
            logger.debug(f"影片資訊: {self.width}x{self.height}, {self.fps:.2f}fps, {self.frame_count} frames")

        except Exception as e:
            raise VideoIOError(f"初始化影片讀取器失敗: {e}")

        # for old version
        self.video_path = self.path
        self.video_name = os.path.basename(self.video_path)
        self.total_frames = self.__len__()

    def __enter__(self) -> 'VideoReader':
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit the context manager; closes the reader."""
        self.close()

    def __len__(self) -> int:
        """Return the total frame count."""
        return self.frame_count

    def __getitem__(self, frame_idx: int) -> np.ndarray:
        """Support the frame = reader[index] syntax."""
        return self.get_frame(frame_idx)

    def __del__(self) -> None:
        """Destructor; makes sure resources are released."""
        self.close()

    def _calculate_frame_count(self) -> int:
        """
        Compute the total frame count.

        Uses stream.frames directly to avoid decoding frames during
        initialization, which would disturb the container state.

        Returns:
            Total number of frames.
        """
        # First try the stream's frames attribute.
        stream_frames = self.video_stream.frames

        if stream_frames and stream_frames > 0:
            # Trust stream.frames directly; do not decode frames during init.
            logger.debug(f"使用 stream.frames: {stream_frames}")
            return stream_frames

        # Fall back to binary search when stream.frames is unavailable.
        logger.warning(f"stream.frames 不可用 ({stream_frames}),使用二分搜尋法計算影格數")
        return self._binary_search_frame_count()

    def _binary_search_frame_count(self) -> int:
        """Determine the frame count using binary search."""
        return self._binary_search_frame_count_from(0)

    def _binary_search_frame_count_from(self, start_frame: int) -> int:
        """Determine the frame count via binary search starting at the given frame."""
        low, high = start_frame, 1000000  # assume at most 1M frames

        while low < high:
            mid = (low + high + 1) // 2
            try:
                self._get_frame_direct(mid)
                low = mid
            except (RuntimeError, av.error.EOFError, IndexError, Exception):
                # Any read error means the frame does not exist.
                high = mid - 1

        return low + 1

    def _get_frame_direct(self, frame_index: int) -> np.ndarray:
        """
        Read a frame directly (internal helper).

        Args:
            frame_index: Frame index.

        Returns:
            Frame array (RGB).

        Raises:
            RuntimeError: If the read fails.
        """
        if self._closed:
            raise RuntimeError("影片讀取器已關閉")

        # Special-case frame 0 to make sure the container is at the start.
        if frame_index == 0:
            logger.debug(f"正在嘗試直接讀取影格 0")
            try:
                # Re-seek to time 0 so decoding starts from the beginning.
                self.container.seek(0, stream=self.video_stream, backward=True, any_frame=False)
                # Take the first decoded frame.
                for frame in self.container.decode(self.video_stream):
                    self._current_index = 0
                    return frame.to_rgb().to_ndarray()
                raise RuntimeError("無法從影片中讀取到第一個影格")
            except Exception as e:
                raise RuntimeError(f"直接讀取影格 0 失敗: {e}")

        # Check for a sequential read (the frame right after the current one).
        if frame_index == self._current_index + 1:
            try:
                self._current_index = frame_index
                frame = next(self.container.decode(self.video_stream))
                return frame.to_rgb().to_ndarray()
            except (StopIteration, av.error.EOFError):
                # Sequential read failed (EOF or similar); fall back to seek.
                logger.debug(f"順序讀取影格 {frame_index} 失敗,改用 seek 方式")
                pass
            except Exception as e:
                # Other exceptions also fall back to the seek path.
                logger.debug(f"順序讀取影格 {frame_index} 時發生異常: {e},改用 seek 方式")
                pass

        # Non-sequential read, or sequential read failed: use seek.
        return self._seek_and_read_frame(frame_index)

    def _seek_and_read_frame(self, frame_index: int) -> np.ndarray:
        """
        Read the requested frame via container seek.

        Uses simplified logic similar to the older implementation.

        Args:
            frame_index: Frame index.

        Returns:
            Frame array (RGB).

        Raises:
            RuntimeError: If the read fails.
        """
        try:
            # Compute the timestamp and seek (same logic as the old version).
            timestamp = frame_index / self.pts2index
            self.container.seek(int(timestamp), stream=self.video_stream, backward=True)

            # Look for the matching frame (simplified, as in the old version).
            target_frame = None

            for frame in self.container.decode(self.video_stream):
                index = int(frame.pts * self.pts2index)
                if index == frame_index:
                    target_frame = frame
                    self._current_index = frame_index
                    break
                # Stop once we are well past the target index (with a
                # generous tolerance window).
                if index > frame_index + 5:
                    break

            if target_frame is not None:
                return target_frame.to_rgb().to_ndarray()

            # Exact match failed; try a more tolerant lookup.
            return self._fallback_frame_read(frame_index)

        except (av.error.EOFError, StopIteration):
            # EOF: attempt the tolerant fallback read.
            logger.debug(f"seek 讀取影格 {frame_index} 時遇到 EOF,嘗試容錯讀取")
            return self._fallback_frame_read(frame_index)
        except Exception as e:
            raise RuntimeError(f"使用 seek 方式讀取影格 {frame_index} 失敗: {e}")

    def _fallback_frame_read(self, frame_index: int) -> np.ndarray:
        """
        Tolerant frame read used when exact matching fails.

        Keeps the logic simple and only accepts a very small index error,
        to avoid returning a wrong frame.

        Args:
            frame_index: Frame index.

        Returns:
            Frame array (RGB).

        Raises:
            RuntimeError: If the read fails.
        """
        try:
            # Re-seek slightly earlier so the target frame is reachable.
            timestamp = (frame_index - 2) / self.pts2index  # seek 2 frames back
            self.container.seek(int(timestamp), stream=self.video_stream, backward=True)

            best_frame = None
            best_distance = float('inf')

            # Tolerant search, but only accept a very small error.
            for frame in self.container.decode(self.video_stream):
                try:
                    index = int(frame.pts * self.pts2index)
                    distance = abs(index - frame_index)

                    # Only accept frames within distance <= 1 (±1 error).
                    if distance <= 1 and distance < best_distance:
                        best_distance = distance
                        best_frame = frame

                    # Return immediately on an exact match.
                    if distance == 0:
                        break

                    # Stop searching once we are well past the target.
                    if index > frame_index + 5:
                        break

                except (AttributeError, TypeError):
                    # Some frames may lack a valid pts; skip them.
                    continue

            if best_frame is not None:
                self._current_index = frame_index
                if best_distance > 0:
                    logger.debug(f"容錯讀取影格 {frame_index},實際偏移: {best_distance}")
                return best_frame.to_rgb().to_ndarray()

            raise RuntimeError(f"無法找到影格 {frame_index} 或其相近幀")

        except Exception as e:
            raise RuntimeError(f"容錯讀取影格 {frame_index} 失敗: {e}")

    def get_frame(self, frame_idx: int) -> np.ndarray:
        """
        Get the frame at the given index.

        Args:
            frame_idx: Frame index.

        Returns:
            Frame array (RGB format).

        Raises:
            IndexError: If the frame index is out of range.
            RuntimeError: If the read fails.
        """
        if frame_idx < 0 or frame_idx >= self.frame_count:
            raise IndexError(f"影格索引 {frame_idx} 超出範圍 [0, {self.frame_count})")

        # Serve from the cache when possible.
        if frame_idx in self._frame_cache:
            return self._frame_cache[frame_idx]

        # Read the frame.
        try:
            frame = self._get_frame_direct(frame_idx)
        except RuntimeError as e:
            # A failure near the end may mean frame_count was overestimated.
            if frame_idx >= self.frame_count - 10:
                logger.warning(f"影格 {frame_idx} 讀取失敗,可能超出實際影片範圍: {e}")
                raise IndexError(f"影格索引 {frame_idx} 可能超出實際影片範圍")
            else:
                # Other errors are re-raised as-is.
                raise

        # Cache management (bounded cache size).
        if len(self._frame_cache) >= 100:
            # Evict the oldest cached entry.
            oldest_key = next(iter(self._frame_cache))
            del self._frame_cache[oldest_key]

        self._frame_cache[frame_idx] = frame
        logger.debug(f"成功讀取並快取影格 {frame_idx}")

        return frame

    def iterate_frames(
        self,
        start: int = 0,
        end: Optional[int] = None,
        step: int = 1
    ) -> Generator[Tuple[int, np.ndarray], None, None]:
        """
        Iterate over video frames.

        Args:
            start: Start frame index.
            end: End frame index (exclusive); None means up to the last frame.
            step: Step size.

        Yields:
            (frame index, frame array) tuples.

        Raises:
            ValueError: If the parameters are invalid.
        """
        if end is None:
            end = self.frame_count

        if start < 0 or start >= self.frame_count:
            raise ValueError(f"起始索引 {start} 無效")

        if end > self.frame_count:
            end = self.frame_count

        if step <= 0:
            raise ValueError("步長必須為正數")

        for i in range(start, end, step):
            try:
                frame = self.get_frame(i)
                yield i, frame
            except Exception as e:
                logger.warning(f"跳過影格 {i}: {e}")
                continue

    def get_batch_frames(self, indices: List[int]) -> List[np.ndarray]:
        """
        Fetch multiple frames in one call.

        Args:
            indices: List of frame indices.

        Returns:
            List of frame arrays.

        Raises:
            ValueError: If the index list is empty.
        """
        if not indices:
            raise ValueError("索引列表不能為空")

        frames = []
        for idx in indices:
            try:
                frame = self.get_frame(idx)
                frames.append(frame)
            except Exception as e:
                logger.warning(f"跳過影格 {idx}: {e}")
                # Append a blank placeholder frame so indices stay aligned.
                frames.append(np.zeros((self.height, self.width, 3), dtype=np.uint8))

        return frames

    def get_info(self) -> VideoInfo:
        """
        Get video information.

        Returns:
            A VideoInfo instance with the video's basic properties.
        """
        return VideoInfo(
            path=self.path,
            fps=self.fps,
            frame_count=self.frame_count,
            width=self.width,
            height=self.height,
            duration=self.duration
        )

    def clear_cache(self) -> None:
        """Clear the frame cache."""
        self._frame_cache.clear()
        logger.debug("影格快取已清除")

    def close(self) -> None:
        """Close the video reader and release its resources."""
        if not self._closed:
            try:
                if hasattr(self, 'container') and self.container:
                    self.container.close()
                self._frame_cache.clear()
                self._closed = True
                logger.debug("影片讀取器已關閉")
            except Exception as e:
                logger.error(f"關閉影片讀取器時發生錯誤: {e}")

__del__()

解構函數,確保資源被正確釋放

Source code in castle/utils/video_io.py
def __del__(self) -> None:
    """Destructor: release resources even if close() was never called explicitly."""
    self.close()

__enter__()

Context manager 進入

Source code in castle/utils/video_io.py
def __enter__(self) -> 'VideoReader':
    """Enter the context manager; returns the reader itself."""
    return self

__exit__(exc_type, exc_val, exc_tb)

Context manager 退出

Source code in castle/utils/video_io.py
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the context manager, closing the reader."""
    self.close()

__getitem__(frame_idx)

支援 frame = reader[index] 語法

Source code in castle/utils/video_io.py
def __getitem__(self, frame_idx: int) -> np.ndarray:
    """Support the ``frame = reader[index]`` syntax."""
    return self.get_frame(frame_idx)

__init__(video_path)

初始化影片讀取器

Parameters:

Name Type Description Default
video_path Union[str, Path]

影片檔案路徑

required

Raises:

Type Description
VideoIOError

當影片初始化失敗時

FileNotFoundError

當影片檔案不存在時

Source code in castle/utils/video_io.py
def __init__(self, video_path: Union[str, Path]):
    """
    初始化影片讀取器

    Args:
        video_path: 影片檔案路徑

    Raises:
        VideoIOError: 當影片初始化失敗時
        FileNotFoundError: 當影片檔案不存在時
    """
    self.path = Path(video_path)
    if not self.path.exists():
        raise FileNotFoundError(f"影片檔案不存在: {video_path}")


    try:
        # 開啟影片容器
        self.container = av.open(str(self.path))
        self.video_stream = self.container.streams.video[0]

        # 計算影格索引轉換參數
        self.fps = float(self.video_stream.average_rate)
        time_base = self.video_stream.time_base
        self.pts2index = time_base * self.video_stream.average_rate

        # 獲取影片尺寸
        self.width = self.video_stream.width
        self.height = self.video_stream.height

        # 計算總影格數(需要特殊處理以確保準確性)
        self.frame_count = self._calculate_frame_count()
        self.duration = self.frame_count / self.fps if self.fps > 0 else 0

        # 內部狀態
        self._current_index = -1
        self._frame_cache = {}
        self._closed = False

        logger.debug(f"影片讀取器初始化完成: {self.path}")
        logger.debug(f"影片資訊: {self.width}x{self.height}, {self.fps:.2f}fps, {self.frame_count} frames")

    except Exception as e:
        raise VideoIOError(f"初始化影片讀取器失敗: {e}")

    # for old version
    self.video_path = self.path
    self.video_name = os.path.basename(self.video_path)
    self.total_frames = self.__len__()

__len__()

返回總影格數量

Source code in castle/utils/video_io.py
def __len__(self) -> int:
    """Return the total number of frames."""
    return self.frame_count

clear_cache()

清除影格快取

Source code in castle/utils/video_io.py
def clear_cache(self) -> None:
    """Drop every entry from the in-memory frame cache."""
    self._frame_cache.clear()
    logger.debug("影格快取已清除")

close()

關閉影片讀取器並釋放資源

Source code in castle/utils/video_io.py
def close(self) -> None:
    """Close the video reader and release its resources.

    Idempotent: subsequent calls are no-ops. Errors during shutdown
    are logged rather than raised.
    """
    if not self._closed:
        try:
            # The container attribute may be absent if __init__ failed part-way.
            if hasattr(self, 'container') and self.container:
                self.container.close()
            self._frame_cache.clear()
            self._closed = True
            logger.debug("影片讀取器已關閉")
        except Exception as e:
            logger.error(f"關閉影片讀取器時發生錯誤: {e}")

get_batch_frames(indices)

批次獲取影格

Parameters:

Name Type Description Default
indices List[int]

影格索引列表

required

Returns:

Type Description
List[ndarray]

影格陣列列表

Raises:

Type Description
ValueError

當索引列表為空時

Source code in castle/utils/video_io.py
def get_batch_frames(self, indices: List[int]) -> List[np.ndarray]:
    """
    Fetch a batch of frames by index.

    Args:
        indices: Frame indices to read.

    Returns:
        One frame array per requested index; unreadable frames are
        replaced by black placeholders so the output stays aligned
        with the input indices.

    Raises:
        ValueError: If the index list is empty.
    """
    if not indices:
        raise ValueError("索引列表不能為空")

    frames = []
    for idx in indices:
        try:
            frame = self.get_frame(idx)
            frames.append(frame)
        except Exception as e:
            logger.warning(f"跳過影格 {idx}: {e}")
            # Append a black placeholder so positions match the requested indices.
            frames.append(np.zeros((self.height, self.width, 3), dtype=np.uint8))

    return frames

get_frame(frame_idx)

獲取指定影格

Parameters:

Name Type Description Default
frame_idx int

影格索引

required

Returns:

Type Description
ndarray

影格陣列 (RGB 格式)

Raises:

Type Description
IndexError

當影格索引超出範圍時

RuntimeError

當讀取失敗時

Source code in castle/utils/video_io.py
def get_frame(self, frame_idx: int) -> np.ndarray:
    """
    Fetch a single frame by index, with caching.

    Args:
        frame_idx: Frame index.

    Returns:
        Frame array (RGB format).

    Raises:
        IndexError: If the frame index is out of range.
        RuntimeError: If reading fails.
    """
    if frame_idx < 0 or frame_idx >= self.frame_count:
        raise IndexError(f"影格索引 {frame_idx} 超出範圍 [0, {self.frame_count})")

    # Serve from the cache when possible.
    if frame_idx in self._frame_cache:
        return self._frame_cache[frame_idx]

    # Decode the frame.
    try:
        frame = self._get_frame_direct(frame_idx)
    except RuntimeError as e:
        # A failure near the end of the video usually means frame_count
        # was over-estimated; report it as an out-of-range index.
        if frame_idx >= self.frame_count - 10:
            logger.warning(f"影格 {frame_idx} 讀取失敗,可能超出實際影片範圍: {e}")
            raise IndexError(f"影格索引 {frame_idx} 可能超出實際影片範圍")
        else:
            # Any other failure is propagated unchanged.
            raise

    # Cache management (bound the cache size).
    if len(self._frame_cache) >= 100:
        # Evict the oldest cached entry (dicts preserve insertion order).
        oldest_key = next(iter(self._frame_cache))
        del self._frame_cache[oldest_key]

    self._frame_cache[frame_idx] = frame
    logger.debug(f"成功讀取並快取影格 {frame_idx}")

    return frame

get_info()

獲取影片資訊

Returns:

Type Description
VideoInfo

VideoInfo 實例,包含影片的基本資訊

Source code in castle/utils/video_io.py
def get_info(self) -> VideoInfo:
    """
    Return basic metadata about the video.

    Returns:
        VideoInfo instance (path, fps, frame count, size, duration).
    """
    return VideoInfo(
        path=self.path,
        fps=self.fps,
        frame_count=self.frame_count,
        width=self.width,
        height=self.height,
        duration=self.duration
    )

iterate_frames(start=0, end=None, step=1)

迭代影片影格

Parameters:

Name Type Description Default
start int

起始影格索引

0
end Optional[int]

結束影格索引(不包含),None 表示到影片結尾

None
step int

步長

1

Yields:

Type Description
Tuple[int, ndarray]

(影格索引, 影格陣列) 元組

Raises:

Type Description
ValueError

當參數無效時

Source code in castle/utils/video_io.py
def iterate_frames(
    self,
    start: int = 0,
    end: Optional[int] = None,
    step: int = 1
) -> Generator[Tuple[int, np.ndarray], None, None]:
    """
    Iterate over video frames, yielding (index, frame) pairs.

    Args:
        start: First frame index.
        end: Exclusive end index; None means iterate to the end of the video.
        step: Stride between yielded frames.

    Yields:
        (frame index, frame array) tuples. Frames that fail to decode are
        skipped with a warning instead of aborting the iteration.

    Raises:
        ValueError: If the arguments are invalid.
    """
    # Default to the full video and clamp the end to the frame count.
    stop = self.frame_count if end is None else min(end, self.frame_count)

    if not 0 <= start < self.frame_count:
        raise ValueError(f"起始索引 {start} 無效")

    if step <= 0:
        raise ValueError("步長必須為正數")

    i = start
    while i < stop:
        try:
            yield i, self.get_frame(i)
        except Exception as e:
            logger.warning(f"跳過影格 {i}: {e}")
        i += step

VideoWriter

影片寫入器

使用 av 庫實現高效的影片編碼和寫入功能。 基於提供的 WriteArray 範例實現,支援多種編碼器和品質設定。

範例用法: # 基本使用 writer = VideoWriter("output.mp4", fps=30.0) for frame in frames: writer.write_frame(frame) writer.close()

# 作為 context manager 使用
with VideoWriter("output.mp4", fps=30.0, crf=20) as writer:
    for frame in frames:
        writer.write_frame(frame)
Source code in castle/utils/video_io.py
class VideoWriter:
    """
    Video writer.

    Uses the av library for efficient video encoding and writing.
    Based on the provided WriteArray example; supports multiple codecs
    and quality settings.

    Example usage:
        # Basic usage
        writer = VideoWriter("output.mp4", fps=30.0)
        for frame in frames:
            writer.write_frame(frame)
        writer.close()

        # As a context manager
        with VideoWriter("output.mp4", fps=30.0, crf=20) as writer:
            for frame in frames:
                writer.write_frame(frame)
    """

    def __init__(
        self,
        output_path: Union[str, Path],
        fps: float = 30.0,
        crf: int = 15,
        codec: str = 'libx264'
    ):
        """
        Initialize the video writer.

        Args:
            output_path: Output file path.
            fps: Frame rate.
            crf: Compression quality (0-51; lower means better quality, larger file).
            codec: Encoder name.

        Raises:
            VideoIOError: If initialization fails.
        """
        # Initialize state BEFORE any fallible work so that
        # __del__ -> close() never hits an AttributeError on a
        # half-constructed instance when av.open() fails below.
        self._initialized = False
        self._closed = False
        self._frame_count = 0

        try:
            self.output_path = Path(output_path)
            self.output_path.parent.mkdir(parents=True, exist_ok=True)

            self.output = av.open(str(self.output_path), 'w')
            # Convert fps to a Fraction object for PyAV 15.1.0 support.
            fps_fraction = Fraction(fps).limit_denominator()
            self.stream = self.output.add_stream(codec, rate=fps_fraction)
            self.stream.options = {'crf': str(crf)}
            self.stream.pix_fmt = 'yuv420p'

            logger.debug(f"影片寫入器初始化完成: {self.output_path}")

        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise VideoIOError(f"初始化影片寫入器失敗: {e}") from e

        # Backwards-compatible aliases for the old WriteArray API.
        self.video_path = output_path
        self.video_name = os.path.basename(self.video_path)

    def __enter__(self) -> 'VideoWriter':
        """Enter the context manager; returns the writer itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit the context manager, flushing and closing the writer."""
        self.close()

    def __del__(self) -> None:
        """Destructor: release resources even if close() was never called."""
        self.close()

    def write_frame(self, frame: np.ndarray) -> None:
        """
        Write one frame.

        Args:
            frame: Frame array of shape (H, W, 3) in RGB format.

        Raises:
            ValueError: If the frame format is invalid.
            VideoIOError: If writing fails.
        """
        if self._closed:
            raise VideoIOError("影片寫入器已關閉")

        if not isinstance(frame, np.ndarray):
            raise ValueError("影格必須是 numpy 陣列")

        if frame.ndim != 3 or frame.shape[2] != 3:
            raise ValueError("預期影格形狀為 (H, W, 3)")

        try:
            # The encoded output size is fixed from the first frame written.
            if not self._initialized:
                self.stream.height, self.stream.width = frame.shape[:2]
                self._initialized = True
                logger.debug(f"設定影片尺寸: {self.stream.width}x{self.stream.height}")

            # Normalize dtype to uint8 (floats are assumed to be in [0, 1]).
            if frame.dtype != np.uint8:
                if frame.dtype == np.float32 or frame.dtype == np.float64:
                    frame = (frame * 255).astype(np.uint8)
                else:
                    frame = frame.astype(np.uint8)

            # Build the VideoFrame and encode it.
            av_frame = av.VideoFrame.from_ndarray(frame, format='rgb24')

            for packet in self.stream.encode(av_frame):
                self.output.mux(packet)

            self._frame_count += 1

            if self._frame_count % 100 == 0:
                logger.debug(f"已寫入 {self._frame_count} 個影格")

        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise VideoIOError(f"寫入影格失敗: {e}") from e

    def append(self, frame: np.ndarray) -> None:
        """
        Append a frame (backward-compatibility method).

        Provided for compatibility with the old WriteArray API; simply
        delegates to write_frame.

        Args:
            frame: Frame array of shape (H, W, 3) in RGB format.

        Raises:
            ValueError: If the frame format is invalid.
            VideoIOError: If writing fails.
        """
        self.write_frame(frame)

    def close(self) -> None:
        """Flush the encoder and close the video writer (idempotent)."""
        if self._closed:
            return

        try:
            # Flush any frames still buffered inside the encoder.
            if hasattr(self, 'stream'):
                for packet in self.stream.encode():
                    self.output.mux(packet)

            # Close the output container.
            if hasattr(self, 'output'):
                self.output.close()

            self._closed = True
            logger.info(f"影片寫入完成: {self.output_path}, 總共 {self._frame_count} 個影格")

        except Exception as e:
            logger.error(f"關閉影片寫入器時發生錯誤: {e}")

__del__()

解構函數,確保資源被正確釋放

Source code in castle/utils/video_io.py
def __del__(self) -> None:
    """Destructor: release resources even if close() was never called explicitly."""
    self.close()

__enter__()

Context manager 進入

Source code in castle/utils/video_io.py
def __enter__(self) -> 'VideoWriter':
    """Enter the context manager; returns the writer itself."""
    return self

__exit__(exc_type, exc_val, exc_tb)

Context manager 退出

Source code in castle/utils/video_io.py
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the context manager, flushing and closing the writer."""
    self.close()

__init__(output_path, fps=30.0, crf=15, codec='libx264')

初始化影片寫入器

Parameters:

Name Type Description Default
output_path Union[str, Path]

輸出檔案路徑

required
fps float

幀率

30.0
crf int

壓縮品質 (0-51, 數值越小品質越好,檔案越大)

15
codec str

編碼器名稱

'libx264'

Raises:

Type Description
VideoIOError

當初始化失敗時

Source code in castle/utils/video_io.py
def __init__(
    self,
    output_path: Union[str, Path],
    fps: float = 30.0,
    crf: int = 15,
    codec: str = 'libx264'
):
    """
    Initialize the video writer.

    Args:
        output_path: Output file path.
        fps: Frame rate.
        crf: Compression quality (0-51; lower means better quality, larger file).
        codec: Encoder name.

    Raises:
        VideoIOError: If initialization fails.
    """
    try:
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)

        self.output = av.open(str(self.output_path), 'w')
        # Convert fps to a Fraction object for PyAV 15.1.0 support.
        fps_fraction = Fraction(fps).limit_denominator()
        self.stream = self.output.add_stream(codec, rate=fps_fraction)
        self.stream.options = {'crf': str(crf)}
        self.stream.pix_fmt = 'yuv420p'

        self._initialized = False
        self._closed = False
        self._frame_count = 0

        logger.debug(f"影片寫入器初始化完成: {self.output_path}")

    except Exception as e:
        raise VideoIOError(f"初始化影片寫入器失敗: {e}")

    # Aliases kept for backward compatibility with the old WriteArray API.
    self.video_path = output_path
    self.video_name = os.path.basename(self.video_path)

append(frame)

追加影格(向後相容性方法)

這是為了與舊版本 WriteArray 相容而提供的方法。 內部調用 write_frame 方法。

Parameters:

Name Type Description Default
frame ndarray

影格陣列,形狀為 (H, W, 3),RGB 格式

required

Raises:

Type Description
ValueError

當影格格式無效時

VideoIOError

當寫入失敗時

Source code in castle/utils/video_io.py
def append(self, frame: np.ndarray) -> None:
    """
    Append a frame (backward-compatibility method).

    Provided for compatibility with the old WriteArray API; simply
    delegates to write_frame.

    Args:
        frame: Frame array of shape (H, W, 3) in RGB format.

    Raises:
        ValueError: If the frame format is invalid.
        VideoIOError: If writing fails.
    """
    self.write_frame(frame)

close()

關閉影片寫入器並完成編碼

Source code in castle/utils/video_io.py
def close(self) -> None:
    """Flush the encoder and close the video writer (idempotent)."""
    if self._closed:
        return

    try:
        # Flush any frames still buffered inside the encoder.
        if hasattr(self, 'stream'):
            for packet in self.stream.encode():
                self.output.mux(packet)

        # Close the output container.
        if hasattr(self, 'output'):
            self.output.close()

        self._closed = True
        logger.info(f"影片寫入完成: {self.output_path}, 總共 {self._frame_count} 個影格")

    except Exception as e:
        logger.error(f"關閉影片寫入器時發生錯誤: {e}")

write_frame(frame)

寫入一個影格

Parameters:

Name Type Description Default
frame ndarray

影格陣列,形狀為 (H, W, 3),RGB 格式

required

Raises:

Type Description
ValueError

當影格格式無效時

VideoIOError

當寫入失敗時

Source code in castle/utils/video_io.py
def write_frame(self, frame: np.ndarray) -> None:
    """
    Write one frame.

    Args:
        frame: Frame array of shape (H, W, 3) in RGB format.

    Raises:
        ValueError: If the frame format is invalid.
        VideoIOError: If writing fails.
    """
    if self._closed:
        raise VideoIOError("影片寫入器已關閉")

    if not isinstance(frame, np.ndarray):
        raise ValueError("影格必須是 numpy 陣列")

    if frame.ndim != 3 or frame.shape[2] != 3:
        raise ValueError("預期影格形狀為 (H, W, 3)")

    try:
        # The encoded output size is fixed from the first frame written.
        if not self._initialized:
            self.stream.height, self.stream.width = frame.shape[:2]
            self._initialized = True
            logger.debug(f"設定影片尺寸: {self.stream.width}x{self.stream.height}")

        # Normalize dtype to uint8 (floats are assumed to be in [0, 1]).
        if frame.dtype != np.uint8:
            if frame.dtype == np.float32 or frame.dtype == np.float64:
                frame = (frame * 255).astype(np.uint8)
            else:
                frame = frame.astype(np.uint8)

        # Build the VideoFrame and encode it.
        av_frame = av.VideoFrame.from_ndarray(frame, format='rgb24')

        for packet in self.stream.encode(av_frame):
            self.output.mux(packet)

        self._frame_count += 1

        if self._frame_count % 100 == 0:
            logger.debug(f"已寫入 {self._frame_count} 個影格")

    except Exception as e:
        raise VideoIOError(f"寫入影格失敗: {e}")

Video Management

castle.utils.video_manager

Video management utilities for Castle AI.

add_video_to_project(storage_path, project_name, video_source_path, video_name)

Add a video file to the project.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the project

required
video_source_path

Source path of the video file

required
video_name

Name of the video file

required

Returns:

Name Type Description
tuple

(success: bool, message: str)

Source code in castle/utils/video_manager.py
def add_video_to_project(storage_path, project_name, video_source_path, video_name):
    """Add a video file to the project.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the project
        video_source_path: Source path of the video file
        video_name: Name of the video file

    Returns:
        tuple: (success: bool, message: str)
    """
    try:
        project_path, config = get_project_config(storage_path, project_name)

        # Ensure the config carries a source list to work with.
        videos = config.setdefault('source', [])

        # Refuse duplicates within the same project.
        if video_name in videos:
            return False, "Video already exists in this project"

        # Make sure the project's sources directory exists.
        sources_dir = os.path.join(project_path, 'sources')
        os.makedirs(sources_dir, exist_ok=True)

        # Copy the video into the project, then record it in the config.
        shutil.copyfile(video_source_path, os.path.join(sources_dir, video_name))
        videos.append(video_name)
        save_project_config(storage_path, project_name, config)

        return True, f"Successfully added video: {video_name}"

    except FileNotFoundError as e:
        return False, f"File not found: {str(e)}"
    except Exception as e:
        return False, f"Failed to add video: {str(e)}"

add_videos_batch(storage_path, project_name, video_directory, video_list)

Add multiple videos to a project in batch.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the project

required
video_directory

Source directory containing videos

required
video_list

List of video file names to add

required

Returns:

Name Type Description
tuple

(success_count, fail_count, messages)

Source code in castle/utils/video_manager.py
def add_videos_batch(storage_path, project_name, video_directory, video_list):
    """Add multiple videos to a project in batch.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the project
        video_directory: Source directory containing videos
        video_list: List of video file names to add

    Returns:
        tuple: (success_count, fail_count, messages)
    """
    if not video_list:
        return 0, 0, ["No videos to add"]

    messages = []
    success_count = 0

    for video_name in video_list:
        success, message = add_video_to_project(
            storage_path,
            project_name,
            os.path.join(video_directory, video_name),
            video_name
        )
        if success:
            success_count += 1
        else:
            messages.append(f"Failed: {video_name} - {message}")

    fail_count = len(video_list) - success_count

    # Prepend a human-readable summary of the batch result.
    summary = f"Added {success_count} video(s) successfully"
    if fail_count > 0:
        summary += f", {fail_count} failed"
    messages.insert(0, summary)

    return success_count, fail_count, messages

get_project_videos(storage_path, project_name)

Get list of videos in a project.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the project

required

Returns:

Name Type Description
list

List of video file names in the project, sorted

Source code in castle/utils/video_manager.py
def get_project_videos(storage_path, project_name):
    """Get list of videos in a project.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the project

    Returns:
        list: List of video file names in the project, sorted
    """
    try:
        _, config = get_project_config(storage_path, project_name)
        # An absent or empty 'source' entry yields an empty list.
        return sorted(config.get('source', []) or [])
    except Exception as e:
        logger.error(f"Error getting project videos: {e}")
        return []

is_video_file(file_path)

Check if a file is a video based on its extension.

Parameters:

Name Type Description Default
file_path

Path to the file

required

Returns:

Name Type Description
bool

True if the file is a video, False otherwise

Source code in castle/utils/video_manager.py
def is_video_file(file_path):
    """Check if a file is a video based on its extension.

    Args:
        file_path: Path to the file

    Returns:
        bool: True if the file is a video, False otherwise
    """
    # Compare the lowercased extension against the known video extensions.
    return os.path.splitext(file_path)[1].lower() in VIDEO_EXTENSIONS

list_videos_in_directory(directory_path)

List all video files in a directory.

Parameters:

Name Type Description Default
directory_path

Path to the directory

required

Returns:

Name Type Description
list

Sorted list of video file names

Source code in castle/utils/video_manager.py
def list_videos_in_directory(directory_path):
    """List all video files in a directory.

    Args:
        directory_path: Path to the directory

    Returns:
        list: Sorted list of video file names
    """
    if not os.path.exists(directory_path):
        return []

    try:
        entries = os.listdir(directory_path)
        videos = [
            name for name in entries
            if os.path.isfile(os.path.join(directory_path, name)) and is_video_file(name)
        ]
        return sorted(videos)
    except Exception as e:
        logger.error(f"Error listing videos: {e}")
        return []

load_video_metadata(storage_path, project_name, video_name)

Load video file and return metadata.

Parameters:

Name Type Description Default
storage_path

Path to the storage directory

required
project_name

Name of the project

required
video_name

Name of the video file

required

Returns:

Name Type Description
tuple

(video_path, frame_count) or (None, 0) if error

Source code in castle/utils/video_manager.py
def load_video_metadata(storage_path, project_name, video_name):
    """Load video file and return metadata.

    Args:
        storage_path: Path to the storage directory
        project_name: Name of the project
        video_name: Name of the video file

    Returns:
        tuple: (video_path, frame_count) or (None, 0) if error
    """
    try:
        # Imported lazily to avoid a hard dependency at module import time.
        from castle.utils.video_io import ReadArray

        video_path = os.path.join(storage_path, project_name, 'sources', video_name)

        if not os.path.exists(video_path):
            logger.warning(f"Video file not found: {video_path}")
            return None, 0

        # Opening the video is the reliable way to determine the frame count.
        return video_path, len(ReadArray(video_path))

    except Exception as e:
        logger.error(f"Error loading video metadata: {e}")
        return None, 0

scan_video_directory(directory_path)

Scan a directory and return video statistics.

Parameters:

Name Type Description Default
directory_path

Path to the directory to scan

required

Returns:

Name Type Description
tuple

(video_list, summary_text) where video_list is the list of all videos and summary_text is a formatted string showing statistics

Source code in castle/utils/video_manager.py
def scan_video_directory(directory_path):
    """Scan a directory and return video statistics.

    Args:
        directory_path: Path to the directory to scan

    Returns:
        tuple: (video_list, summary_text) where video_list is the list of all
               videos and summary_text is a formatted string with statistics
    """
    if not directory_path or not os.path.exists(directory_path):
        return [], "Directory not found or invalid path"

    try:
        videos = list_videos_in_directory(directory_path)

        if not videos:
            return [], "No video files found in this directory"

        video_count = len(videos)

        # Show at most the first five file names as a preview.
        lines = [f"Found {video_count} video file(s)\n", "Preview (first 5 files):"]
        lines.extend(f"{i}. {video}" for i, video in enumerate(videos[:5], 1))

        summary = "\n".join(lines) + "\n"
        if video_count > 5:
            summary += f"\n... and {video_count - 5} more file(s)"

        return videos, summary

    except Exception as e:
        return [], f"Error scanning directory: {str(e)}"

Image Segmentation (SAM)

castle.utils.image_segment

Segmentor

Source code in castle/utils/image_segment.py
class Segmentor:
    """Wrapper around SAM providing automatic and interactive segmentation."""

    def __init__(self, sam_args):
        """
        Initialize the SAM model and its predictors.

        sam_args:
            sam_checkpoint: path of SAM checkpoint
            generator_args: args for everything_generator
            device: device
        """
        self.device = sam_args["device"]
        self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"])
        self.sam.to(device=self.device)
        self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args'])
        # The automatic generator and the interactive predictor share one model.
        self.interactive_predictor = self.everything_generator.predictor
        self.have_embedded = False

    @torch.no_grad()
    def set_image(self, image):
        """Compute and cache the image embedding (no-op if already embedded)."""
        if not self.have_embedded:
            self.interactive_predictor.set_image(image)
            self.have_embedded = True

    @torch.no_grad()
    def interactive_predict(self, prompts, mode, multimask=True):
        """
        Run interactive SAM prediction with the given prompts.

        Args:
            prompts: Dict with 'point_coords' / 'point_modes' / 'mask_prompt'
                entries as required by *mode*.
            mode: One of 'point', 'mask', 'point_mask'.
            multimask: Whether SAM should return multiple mask candidates.

        Returns:
            (masks, scores, logits) as returned by the SAM predictor.

        Raises:
            ValueError: If *mode* is not a recognized prompt mode.
        """
        assert self.have_embedded, 'image embedding for sam need be set before predict.'

        if mode == 'point':
            masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
                                point_labels=prompts['point_modes'],
                                multimask_output=multimask)
        elif mode == 'mask':
            masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'],
                                multimask_output=multimask)
        elif mode == 'point_mask':
            masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
                                point_labels=prompts['point_modes'],
                                mask_input=prompts['mask_prompt'],
                                multimask_output=multimask)
        else:
            # Previously an unknown mode crashed with UnboundLocalError at the
            # return statement below; fail fast with a clear error instead.
            raise ValueError(f"Unknown prediction mode: {mode!r}")

        return masks, scores, logits

    @torch.no_grad()
    def segment_with_click(self, origin_frame, coords, modes, multimask=True):
        """
        Segment an object from click prompts, refining with a second pass.

        Args:
            origin_frame: Image to segment.
            coords: Click point coordinates.
            modes: Point labels (positive/negative clicks).
            multimask: Whether SAM should return multiple mask candidates.

        Returns:
            Binary (one-hot) uint8 mask of the selected object.
        """
        self.set_image(origin_frame)

        # First pass: points only.
        prompts = {
            'point_coords': coords,
            'point_modes': modes,
        }
        masks, scores, logits = self.interactive_predict(prompts, 'point', multimask)
        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
        # Second pass: feed the best logit back as a mask prompt to refine.
        prompts = {
            'point_coords': coords,
            'point_modes': modes,
            'mask_prompt': logit[None, :, :]
        }
        masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask)
        mask = masks[np.argmax(scores)]

        return mask.astype(np.uint8)

__init__(sam_args)

sam_args

sam_checkpoint: path of the SAM checkpoint file; generator_args: arguments passed to the automatic ("everything") mask generator; device: the torch device on which to run the model

Source code in castle/utils/image_segment.py
def __init__(self, sam_args):
    """
    Initialize the SAM model and its predictors.

    sam_args:
        sam_checkpoint: path of SAM checkpoint
        generator_args: args for everything_generator
        device: device
    """
    self.device = sam_args["device"]
    self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"])
    self.sam.to(device=self.device)
    self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args'])
    # The automatic generator and the interactive predictor share one model.
    self.interactive_predictor = self.everything_generator.predictor
    self.have_embedded = False

segment_with_click(origin_frame, coords, modes, multimask=True)

return

mask: a one-hot (binary) mask of the selected object

Source code in castle/utils/image_segment.py
@torch.no_grad()
def segment_with_click(self, origin_frame, coords, modes, multimask=True):
    '''
    Segment an object from click prompts, refining with a second pass.

    return:
        mask: a one-hot (binary) uint8 mask of the selected object
    '''
    self.set_image(origin_frame)

    # First pass: points only.
    prompts = {
        'point_coords': coords,
        'point_modes': modes,
    }
    masks, scores, logits = self.interactive_predict(prompts, 'point', multimask)
    mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
    # Second pass: feed the best logit back as a mask prompt to refine.
    prompts = {
        'point_coords': coords,
        'point_modes': modes,
        'mask_prompt': logit[None, :, :]
    }
    masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask)
    mask = masks[np.argmax(scores)]

    return mask.astype(np.uint8)

ROI Tracking

castle.utils.tracking_manager

ROI tracking management utilities for Castle AI.

ROITracker

ROI tracker for performing video object tracking using reference frames and masks.

Source code in castle/utils/tracking_manager.py
class ROITracker:
    """ROI tracker for performing video object tracking using reference frames and masks."""

    def __init__(
        self,
        storage_path: str,
        project_name: str,
        video_source: Any,
        start_frame: int,
        stop_frame: int,
        model_type: str = "r50_deaotl",
    ) -> None:
        """Initialize the ROI tracker.

        Args:
            storage_path: Base storage directory
            project_name: Name of the project
            video_source: Video source object
            start_frame: Starting frame index
            stop_frame: Stopping frame index
            model_type: Tracking model type (e.g., 'r50_deaotl', 'swinb_deaotl')
        """
        self.cancel = False
        self.show_middle_result = False
        self.model_type = model_type

        # Setup paths
        project_path = Path(storage_path) / project_name
        self.track_dir = project_path / "track" / video_source.video_name
        self.track_dir.mkdir(parents=True, exist_ok=True)

        # Video parameters
        self.video_source = video_source
        self.start_frame = int(start_frame)
        self.stop_frame = int(stop_frame)
        self.max_memory_length = 30

        # Load reference knowledge from labels.
        # NOTE: labels are read project-wide (no video_name filter), so
        # reference frames from every labeled video in the project are used.
        self.reference_frames = []
        label_list = read_roi_labels(storage_path, project_name)
        self.n_rois = 0

        for label in label_list:
            frame, mask = label["frame"], label["mask"]
            self.reference_frames.append((frame, mask))
            # n_rois is the largest object id observed across reference masks.
            self.n_rois = max(self.n_rois, int(np.max(mask)))

        # Current frame and mask for display
        self.current_frame = None
        self.current_mask = None

        # --- Smart Filtering Initialization ---
        # Per-object minimum-area threshold: 10% of the reference mask area.
        # With multiple references the last value wins; references are
        # assumed to be consistent in scale.
        self.smart_thresholds = {}
        for (_, mask) in self.reference_frames:
            for obj_id in np.unique(mask):
                if obj_id == 0:
                    continue
                area = np.sum(mask == obj_id)
                self.smart_thresholds[obj_id] = area * 0.1
        print(f"Smart Filtering Thresholds: {self.smart_thresholds}")

    def _smart_filter(self, mask: np.ndarray) -> np.ndarray:
        """Apply automated smart filtering: Keep Largest Component (> threshold).

        For each object id present in ``mask``, keep only its largest
        connected component whose area exceeds that object's threshold;
        everything else is zeroed out.

        Args:
            mask: Integer label mask (0 = background).

        Returns:
            np.ndarray: Filtered mask of the same shape and dtype.
        """
        # Nothing to filter on an all-background mask.
        if mask.max() == 0:
            return mask

        new_mask = np.zeros_like(mask)

        for obj_id in np.unique(mask):
            if obj_id == 0:
                continue

            # Binary mask for this object
            binary_mask = (mask == obj_id).astype(np.uint8)

            # Connectivity analysis; num_labels includes background (label 0).
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)

            if num_labels <= 1:
                continue

            # Default 50 pixels for objects without a reference-derived threshold.
            threshold = self.smart_thresholds.get(obj_id, 50)

            # Find the largest component that meets the threshold.
            largest_area = -1
            largest_label = -1
            for i in range(1, num_labels):
                area = stats[i, cv2.CC_STAT_AREA]
                if area > threshold and area > largest_area:
                    largest_area = area
                    largest_label = i

            # If we found a valid component, keep it.
            if largest_label != -1:
                new_mask[labels == largest_label] = obj_id

        return new_mask

    def track(self, progress: Optional[gr.Progress] = None, skip_existing: bool = False) -> str:
        """Execute ROI tracking over specified frames using a parallelized DataLoader and batch inference.

        Args:
            progress: Optional Gradio progress object for UI feedback.
            skip_existing: If True and a tracked file already exists, return
                immediately without re-tracking.

        Returns:
            str: "Done" on success, "Cancel" if cancelled, "Skip" if skipped.
        """
        time.sleep(0.5)

        mask_list_path = self.track_dir / "mask_list.h5"

        # Ensure a clean HDF5 file. This is checked BEFORE loading the
        # tracking model so a skip does not pay the model-initialization cost
        # (the original instantiated the tracker first).
        if os.path.exists(mask_list_path):
            if skip_existing:
                print(f"Skipping existing tracked file: {mask_list_path}")
                return "Skip"

            try:
                os.remove(mask_list_path)
                print(f"Removed existing HDF5 file: {mask_list_path}")
            except Exception as e:
                print(f"Warning: Could not remove existing HDF5 file {mask_list_path}. Error: {e}")

        # Initialize tracker model and HDF5 writer
        tracker = generate_aot(model_type=self.model_type)
        mask_seq = H5IO(str(mask_list_path))

        # Write video and ROI configuration
        first_frame = self.video_source[0]
        mask_seq.write_config("n_rois", self.n_rois)
        mask_seq.write_config("total_frames", len(self.video_source))
        mask_seq.write_config("height", first_frame.shape[0])
        mask_seq.write_config("width", first_frame.shape[1])

        # Add all reference ROI frames to tracker's memory
        for frame, mask in self.reference_frames:
            tracker.add_reference_frame(frame, mask, self.n_rois, -1)

        # Determine tracking direction (supports backward tracking).
        delta = 1 if self.start_frame < self.stop_frame else -1
        frame_range = list(range(self.start_frame, self.stop_frame + delta, delta))

        # Initialize DataLoader for batch processing.
        # Limit num_workers to prevent OOM (20% of CPU cores);
        # os.cpu_count() may return None, so fall back to 1.
        num_workers = max(1, int((os.cpu_count() or 1) * 0.2))
        batch_size = 16  # Process frames in batches
        print(f"DEBUG: Tracking with {num_workers} workers and batch size {batch_size}")

        dataset = TrackingDataset(self.video_source, frame_range, tracker.transform)

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        # Iterator with progress bar, handling custom notification callback if present
        if progress is not None:
            iterator = progress.tqdm(loader, desc="Tracking frames")
        else:
            iterator = tqdm(loader, desc="Tracking frames")

        for frame_tensors, frame_indices, original_frames in iterator:
            # Check for cancellation flag (set by cancel_tracking()).
            if self.cancel:
                self.show_middle_result = False
                self.cancel = False
                del mask_seq
                return "Cancel"

            # Prepare batch of original sizes
            original_sizes = [frame.shape[:2] for frame in original_frames.numpy()]

            # Perform batch tracking
            mask_batch = tracker.track_batch(frame_tensors, original_sizes=original_sizes)

            # Process and save the batch of masks
            processed_masks = mask_batch.squeeze(1).detach().cpu().numpy().astype(np.uint8)

            for i in range(len(processed_masks)):
                frame_idx = frame_indices[i].item()
                mask_to_save = processed_masks[i]

                # Apply Smart Filtering (Keep Largest Component per class)
                mask_to_save = self._smart_filter(mask_to_save)

                # Update current state for display (with the last frame of the batch)
                self.current_frame = original_frames[i].numpy()
                self.current_mask = mask_to_save

                # Write mask to HDF5 file
                mask_seq.write_mask(frame_idx, mask_to_save)

        # Cleanup: dropping the H5IO reference releases the HDF5 handle.
        self.show_middle_result = False
        del mask_seq

        return "Done"

    def cancel_tracking(self) -> None:
        """Set flag to cancel tracking (checked once per batch in track())."""
        self.cancel = True

    def toggle_display_mode(self) -> None:
        """Toggle the display of intermediate results."""
        self.show_middle_result = not self.show_middle_result

    def get_current_result(self) -> Tuple[Optional[Any], Optional[Any]]:
        """Get current frame and mask.

        Returns:
            Tuple of (current_frame, current_mask)
        """
        return self.current_frame, self.current_mask

__init__(storage_path, project_name, video_source, start_frame, stop_frame, model_type='r50_deaotl')

Initialize the ROI tracker.

Parameters:

Name Type Description Default
storage_path str

Base storage directory

required
project_name str

Name of the project

required
video_source Any

Video source object

required
start_frame int

Starting frame index

required
stop_frame int

Stopping frame index

required
model_type str

Tracking model type (e.g., 'r50_deaotl', 'swinb_deaotl')

'r50_deaotl'
Source code in castle/utils/tracking_manager.py
def __init__(
    self,
    storage_path: str,
    project_name: str,
    video_source: Any,
    start_frame: int,
    stop_frame: int,
    model_type: str = "r50_deaotl",
) -> None:
    """Initialize the ROI tracker.

    Args:
        storage_path: Base storage directory
        project_name: Name of the project
        video_source: Video source object
        start_frame: Starting frame index
        stop_frame: Stopping frame index
        model_type: Tracking model type (e.g., 'r50_deaotl', 'swinb_deaotl')
    """
    self.cancel = False
    self.show_middle_result = False
    self.model_type = model_type

    # Setup paths
    project_path = Path(storage_path) / project_name
    # video_name = video_source.video_name
    self.track_dir = project_path / "track" / video_source.video_name
    self.track_dir.mkdir(parents=True, exist_ok=True)

    # Video parameters
    self.video_source = video_source
    self.start_frame = int(start_frame)
    self.stop_frame = int(stop_frame)
    self.max_memory_length = 30

    # Load reference knowledge from labels
    self.reference_frames = []
    label_list = read_roi_labels(storage_path, project_name) # 移除 video_source.video_name 參數
    self.n_rois = 0

    for label in label_list:
        frame, mask = label["frame"], label["mask"]
        self.reference_frames.append((frame, mask))
        # Update n_rois to be the maximum value found in masks
        self.n_rois = max(self.n_rois, int(np.max(mask)))

    # Current frame and mask for display
    self.current_frame = None
    self.current_mask = None

    # --- Smart Filtering Initialization ---
    self.smart_thresholds = {}
    for (_, mask) in self.reference_frames:
        obj_ids = np.unique(mask)
        for obj_id in obj_ids:
            if obj_id == 0: continue
            area = np.sum(mask == obj_id)
            # If multiple references, we take the one that gives us a reasonable baseline.
            # Here we just blindly set/overwrite, assuming references are consistent.
            # A safer bet is 10% of the reference area.
            self.smart_thresholds[obj_id] = area * 0.1
    print(f"Smart Filtering Thresholds: {self.smart_thresholds}")

cancel_tracking()

Set flag to cancel tracking.

Source code in castle/utils/tracking_manager.py
def cancel_tracking(self) -> None:
    """Set flag to cancel tracking."""
    self.cancel = True

get_current_result()

Get current frame and mask.

Returns:

Type Description
Tuple[Optional[Any], Optional[Any]]

Tuple of (current_frame, current_mask)

Source code in castle/utils/tracking_manager.py
def get_current_result(self) -> Tuple[Optional[Any], Optional[Any]]:
    """Get current frame and mask.

    Returns:
        Tuple of (current_frame, current_mask)
    """
    return self.current_frame, self.current_mask

toggle_display_mode()

Toggle the display of intermediate results.

Source code in castle/utils/tracking_manager.py
def toggle_display_mode(self) -> None:
    """Toggle the display of intermediate results."""
    self.show_middle_result = not self.show_middle_result

track(progress=None, skip_existing=False)

Execute ROI tracking over specified frames using a parallelized DataLoader and batch inference.

Source code in castle/utils/tracking_manager.py
def track(self, progress: Optional[gr.Progress] = None, skip_existing: bool = False) -> str:
    """Execute ROI tracking over specified frames using a parallelized DataLoader and batch inference."""
    time.sleep(0.5)

    # Initialize tracker model and HDF5 writer
    tracker = generate_aot(model_type=self.model_type)
    mask_list_path = self.track_dir / "mask_list.h5"

    # --- Start of new logic: Ensure a clean HDF5 file ---
    if os.path.exists(mask_list_path):
        if skip_existing:
            print(f"Skipping existing tracked file: {mask_list_path}")
            return "Skip"

        try:
            os.remove(mask_list_path)
            print(f"Removed existing HDF5 file: {mask_list_path}")
        except Exception as e:
            print(f"Warning: Could not remove existing HDF5 file {mask_list_path}. Error: {e}")
    # --- End of new logic ---

    mask_seq = H5IO(str(mask_list_path))

    # Write video and ROI configuration
    first_frame = self.video_source[0]
    mask_seq.write_config("n_rois", self.n_rois)
    mask_seq.write_config("total_frames", len(self.video_source))
    mask_seq.write_config("height", first_frame.shape[0])
    mask_seq.write_config("width", first_frame.shape[1])

    # Add all reference ROI frames to tracker's memory
    for frame, mask in self.reference_frames:
        tracker.add_reference_frame(frame, mask, self.n_rois, -1)

    # Determine tracking direction
    delta = 1 if self.start_frame < self.stop_frame else -1
    frame_range = list(range(self.start_frame, self.stop_frame + delta, delta))

    # Initialize DataLoader for batch processing
    # Initialize DataLoader for batch processing
    # Limit num_workers to prevent OOM (20% of CPU cores)
    num_workers = max(1, int(os.cpu_count() * 0.2))
    batch_size = 16  # Process frames in batches
    print(f"DEBUG: Tracking with {num_workers} workers and batch size {batch_size}")

    dataset = TrackingDataset(self.video_source, frame_range, tracker.transform)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )


    # Iterator with progress bar, handling custom notification callback if present
    if progress is not None:
        iterator = progress.tqdm(loader, desc="Tracking frames")
    else:
        iterator = tqdm(loader, desc="Tracking frames")

    for frame_tensors, frame_indices, original_frames in iterator:
        # Check for cancellation flag
        if self.cancel:
            self.show_middle_result = False
            self.cancel = False
            del mask_seq
            return "Cancel"

        # Prepare batch of original sizes
        original_sizes = [frame.shape[:2] for frame in original_frames.numpy()]

        # Perform batch tracking
        mask_batch = tracker.track_batch(frame_tensors, original_sizes=original_sizes)

        # Process and save the batch of masks
        processed_masks = mask_batch.squeeze(1).detach().cpu().numpy().astype(np.uint8)

        for i in range(len(processed_masks)):
            frame_idx = frame_indices[i].item()
            mask_to_save = processed_masks[i]

            # Apply Smart Filtering (Keep Largest Component per class)
            mask_to_save = self._smart_filter(mask_to_save)

            # Update current state for display (with the last frame of the batch)
            self.current_frame = original_frames[i].numpy()
            self.current_mask = mask_to_save

            # Write mask to HDF5 file
            mask_seq.write_mask(frame_idx, mask_to_save)

    # Cleanup
    self.show_middle_result = False
    del mask_seq

    return "Done"

TrackingDataset

Bases: Dataset

Dataset for lazy loading of video frames for tracking.

Source code in castle/utils/tracking_manager.py
class TrackingDataset(Dataset):
    """Lazily loads video frames for tracking.

    Only the video *path* is stored; each DataLoader worker constructs its
    own reader on first access so file handles are never shared between
    processes.
    """

    def __init__(self, video_source: Any, frame_indices: List[int], transform: Any):
        """
        Initialize the dataset.

        Args:
            video_source: Video source object (e.g., ReadArray)
            frame_indices: List of frame indices to process
            transform: Preprocessing transform to apply to each frame
        """
        self.video_path = video_source.path  # Store path for worker
        self.frame_indices = frame_indices
        self.transform = transform
        # Reader is created lazily inside the worker process.
        self.reader = None

    def __len__(self) -> int:
        return len(self.frame_indices)

    def __getitem__(self, idx: int) -> Tuple[Any, int, Any]:
        if self.reader is None:
            # Each worker gets its own file handle to avoid conflicts
            from .video_io import ReadArray
            self.reader = ReadArray(self.video_path)

        frame_index = self.frame_indices[idx]
        raw_frame = self.reader[frame_index]

        # Apply the preprocessing transform to the raw frame.
        processed = self.transform({'current_img': raw_frame})
        frame_tensor = processed[0]['current_img']

        # Also return the untransformed frame for display purposes.
        return frame_tensor, frame_index, raw_frame

__init__(video_source, frame_indices, transform)

Initialize the dataset.

Parameters:

Name Type Description Default
video_source Any

Video source object (e.g., ReadArray)

required
frame_indices List[int]

List of frame indices to process

required
transform Any

Preprocessing transform to apply to each frame

required
Source code in castle/utils/tracking_manager.py
def __init__(self, video_source: Any, frame_indices: List[int], transform: Any):
    """
    Initialize the dataset.

    Args:
        video_source: Video source object (e.g., ReadArray)
        frame_indices: List of frame indices to process
        transform: Preprocessing transform to apply to each frame
    """
    self.video_path = video_source.path  # Store path for worker
    self.frame_indices = frame_indices
    self.transform = transform
    self.reader = None # Initialize reader to None for lazy loading in worker

read_roi_labels(storage_path, project_name, video_name=None)

Read all ROI label files for the given project.

Parameters:

Name Type Description Default
storage_path str

Base storage directory

required
project_name str

Name of the project

required
video_name Optional[str]

Optional specific video name to filter labels

None

Returns:

Type Description
List[Dict[str, Any]]

List of dictionaries containing label information with keys: - index: String identifier combining file index and video basename - frame: Frame data - mask: Corresponding mask

Source code in castle/utils/tracking_manager.py
def read_roi_labels(storage_path: str, project_name: str, video_name: Optional[str] = None) -> List[Dict[str, Any]]:
    """Read all ROI label files for the given project.

    Args:
        storage_path: Base storage directory
        project_name: Name of the project
        video_name: Optional specific video name to filter labels

    Returns:
        List of dictionaries containing label information with keys:
            - index: String identifier combining file index and video basename
            - frame: Frame data
            - mask: Corresponding mask
    """
    label_dir = Path(storage_path) / project_name / "label"

    if not label_dir.exists():
        return []

    label_list: List[Dict[str, Any]] = []

    # Iterate through all subdirectories in natural sorted order
    for label_folder in natsorted([p for p in label_dir.iterdir() if p.is_dir()]):
        video_basename = label_folder.name

        # Skip if filtering by video name and doesn't match
        if video_name and video_basename != video_name:
            continue

        # Iterate through all .npz files in the folder
        for npz_file in natsorted(list(label_folder.glob("*.npz"))):
            try:
                # Use a context manager so the NpzFile's underlying file
                # handle is closed deterministically (the original left it
                # open until garbage collection).
                with np.load(npz_file) as data:
                    # Expect keys 'frame' and 'mask'
                    if "frame" not in data or "mask" not in data:
                        print(f"Warning: Missing frame or mask in {npz_file}")
                        continue

                    # Arrays are fully loaded into memory on access, so
                    # they remain valid after the file is closed.
                    frame = data["frame"]
                    mask = data["mask"]

                label_list.append({
                    "index": f"{npz_file.stem}, {video_basename}",
                    "frame": frame,
                    "mask": mask,
                })
            except Exception as e:
                print(f"Error loading label file {npz_file}: {str(e)}")
                continue

    return label_list

Feature Extraction

castle.utils.visual_latent_extract

castle/utils/visual_latent_extract.py Wrapper module for backward compatibility. Delegates to castle.core.models.

download_dinov3_ckpt(model_name)

Downloads DINOv3 checkpoint if not exists.

Source code in castle/utils/visual_latent_extract.py
def download_dinov3_ckpt(model_name: str) -> str:
    """Download the DINOv3 checkpoint for *model_name* if not already cached.

    Args:
        model_name: DINOv3 model identifier (key into CKPT_DINO_IDS and
            DINOV3_CONSTANTS['MODEL_TO_CKPT_FILENAME']).

    Returns:
        str: Path where the checkpoint is (or would be) stored. If no
        download ID is known for the model, a warning is logged and the
        returned path may not exist.
    """
    from castle.core.config import DEFAULT_CKPT_DIR, CKPT_DINO_IDS, DINOV3_CONSTANTS
    from castle.utils.download import download_with_gdown
    import logging

    logger = logging.getLogger(__name__)

    os.makedirs(DEFAULT_CKPT_DIR, exist_ok=True)

    # Resolve the checkpoint filename; fall back to "<model_name>.pth".
    filename = DINOV3_CONSTANTS['MODEL_TO_CKPT_FILENAME'].get(model_name, f"{model_name}.pth")
    ckpt_path = DEFAULT_CKPT_DIR / filename

    # Already downloaded — nothing to do.
    if ckpt_path.exists():
        return str(ckpt_path)

    logger.info(f"Downloading {model_name} to {ckpt_path}...")

    file_id = CKPT_DINO_IDS.get(model_name)
    if file_id:
        download_with_gdown(file_id, str(ckpt_path))
    else:
        # The original had a dead "fallback search" placeholder branch here
        # that did nothing; removed. We only log and return the path.
        logger.warning(f"No Google ID found for {model_name}, skipping download.")

    return str(ckpt_path)

generate_dinov2(model_type='dinov2_vitb14', **kwargs)

Wrapper to get DINOv2 encoder from core.

Source code in castle/utils/visual_latent_extract.py
def generate_dinov2(model_type: str = 'dinov2_vitb14', **kwargs) -> VisualEncoder:
    """Backward-compatible wrapper: build a DINOv2 encoder via core.

    Extra keyword arguments are accepted for legacy callers but ignored.
    """
    encoder = get_visual_encoder(model_type)
    return encoder

generate_dinov3(model_type='dinov3_vitb16', notify_func=None, **kwargs)

Wrapper to get DINOv3 encoder from core.

Source code in castle/utils/visual_latent_extract.py
def generate_dinov3(model_type: str = 'dinov3_vitb16', notify_func=None, **kwargs) -> VisualEncoder:
    """Backward-compatible wrapper: build a DINOv3 encoder via core.

    ``notify_func`` was used for Gradio info notifications; it is accepted
    for legacy callers but ignored here.
    """
    encoder = get_visual_encoder(model_type)
    return encoder

Latent Explorer

castle.utils.latent_explorer

generate_distinct_color(index, saturation=0.7, value=0.9)

Generate a distinct color using golden ratio for even distribution in HSV space.

This ensures an unlimited number of visually distinct colors can be generated, preventing clusters from becoming grey when the fixed palette is exhausted.

Source code in castle/utils/latent_explorer.py
def generate_distinct_color(index, saturation=0.7, value=0.9):
    """Return a distinct hex color for *index* via golden-ratio hue spacing.

    Spacing hues by the golden-ratio conjugate keeps any number of colors
    well separated in HSV space, so clusters never fall back to grey when
    a fixed palette would run out.
    """
    import colorsys
    GOLDEN_RATIO_CONJUGATE = 0.618033988749895
    hue = (index * GOLDEN_RATIO_CONJUGATE) % 1.0
    r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
    return f'#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}'

Video Alignment

castle.utils.video_align

get_roi_closest_point_safe(mask, roi_color)

Safely get the point on the ROI closest to the image center. Returns None if the ROI does not exist or an error occurs.

Parameters:

Name Type Description Default
mask

Mask image

required
roi_color

ROI color identifier

required

Returns:

Type Description

tuple (x, y) or None

Source code in castle/utils/video_align.py
def get_roi_closest_point_safe(mask, roi_color):
    """Safely get the ROI point closest to the image center.

    Args:
        mask: Mask image.
        roi_color: ROI color identifier.

    Returns:
        tuple: (x, y) coordinates, or None if the ROI does not exist or
        any error occurs while locating it.
    """
    try:
        roi_contour = get_contour(mask, roi_color)
        h, w = mask.shape[:2]
        center_x, center_y = (w // 2, h // 2)
        (closest_point_x, closest_point_y), _ = find_closest_point((center_x, center_y), roi_contour)
        return (closest_point_x, closest_point_y)
    except Exception:
        # Deliberately broad: this is a best-effort helper and callers treat
        # None as "ROI not found". The original caught the redundant tuple
        # (ValueError, Exception); ValueError is already an Exception.
        return None

rotate_based_on_point(frame, closest_point)

Rotate the image based on a precomputed closest point

Parameters:

Name Type Description Default
frame

Image to rotate

required
closest_point

(x, y) coordinate tuple

required

Returns:

Type Description

Rotated image

Source code in castle/utils/video_align.py
def rotate_based_on_point(frame, closest_point):
    """Rotate the image based on a precomputed closest point.

    The image is rotated around its center by (theta - 90) degrees, where
    theta is the angle from the center to *closest_point*.

    Args:
        frame: Image to rotate.
        closest_point: (x, y) coordinate tuple.

    Returns:
        The rotated image.
    """
    height, width = frame.shape[:2]
    center_x, center_y = (width // 2, height // 2)
    point_x, point_y = closest_point
    # Angle (degrees) from the image center to the point.
    theta = np.arctan2(point_y - center_y, point_x - center_x) * 180. / np.pi
    rotation = cv2.getRotationMatrix2D((center_x, center_y), theta - 90, 1.0)
    return cv2.warpAffine(frame, rotation, (width, height))

HDF5 I/O

castle.utils.h5_io


Visualization

castle.utils.plot


Core — Extractor

Build Note

Auto-documentation for castle.core.extractor is unavailable due to import dependencies that are not installed in the docs build environment. See source code directly: castle/core/extractor.py


Core — Data

Build Note

Auto-documentation for castle.core.data is unavailable due to import dependencies. See source: castle/core/data.py


Core — Models

Build Note

Auto-documentation for castle.core.models is unavailable due to import dependencies. See source: castle/core/models.py