
[Fix] Speed up "--resume" #1548

Open · wants to merge 3 commits into main
30 changes: 23 additions & 7 deletions mmengine/runner/loops.py
@@ -2,7 +2,7 @@
 import bisect
 import logging
 import time
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from torch.utils.data import DataLoader
@@ -156,15 +156,26 @@ class _InfiniteDataloaderIterator:

     def __init__(self, dataloader: DataLoader) -> None:
         self._dataloader = dataloader
-        self._iterator = iter(self._dataloader)
+        self._iterator: Any = iter(self._dataloader)
         self._epoch = 0

     def __iter__(self):
         return self

     def __next__(self) -> Sequence[dict]:
+        return self._next_data()
+
+    def skip_iter(self, iter: int) -> None:
+        for _ in range(iter):
+            self._next_data(skip_loading=True)
+
+    def _next_data(self, skip_loading=False) -> Any:
+        data = None
         try:
-            data = next(self._iterator)
+            if skip_loading:
+                self._iterator._next_index()
+            else:
+                data = next(self._iterator)
         except StopIteration:
             print_log(
                 'Reach the end of the dataloader, it will be '
@@ -188,8 +199,14 @@ def __next__(self) -> Sequence[dict]:
                 # attributes.
                 self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
             time.sleep(2)  # Prevent possible deadlock during epoch transition
-            self._iterator = iter(self._dataloader)
-            data = next(self._iterator)
+
+            bypass_mypy_checking: Any = iter(self._dataloader)
+            self._iterator = bypass_mypy_checking
+
+            if skip_loading:
+                bypass_mypy_checking._next_index()
+            else:
+                data = next(self._iterator)
         return data

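The change above hinges on a detail of PyTorch's dataloader internals: the iterator returned by iter(DataLoader) has a private _next_index() method that only draws the next batch of indices from the sampler, without calling the dataset's __getitem__ or collating anything. A minimal sketch of that effect, not part of the PR (SlowDataset is a hypothetical toy dataset, and _next_index() is a private API that may change between PyTorch versions):

import time

import torch
from torch.utils.data import DataLoader, Dataset


class SlowDataset(Dataset):
    """Toy dataset whose __getitem__ simulates expensive I/O."""

    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        time.sleep(0.01)  # pretend to read and decode a sample from disk
        return torch.tensor(idx)


loader = DataLoader(SlowDataset(), batch_size=4, num_workers=0)
it = iter(loader)

# Fast-forward 100 batches without touching __getitem__ at all:
# only sampler indices are consumed.
start = time.time()
for _ in range(100):
    it._next_index()  # private PyTorch API; draws indices, loads no data
print(f'skip 100 batches: {time.time() - start:.4f}s')

# The next real batch continues from where the skip left off.
print(next(it))

Note that the sketch uses num_workers=0; with worker processes the iterator prefetches batches in the background, so the interaction with skipping is more subtle.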
@@ -280,8 +297,7 @@ def run(self) -> None:
                 'that has already been trained',
                 logger='current',
                 level=logging.WARNING)
-            for _ in range(self._iter):
-                next(self.dataloader_iterator)
+            self.dataloader_iterator.skip_iter(self._iter)
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
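At the training-loop level, the same idea shows up as skip_iter replacing the old fast-forward loop, so resuming at iteration N no longer pays the data-loading cost of N batches. A rough, self-contained comparison sketch, assuming mmengine with this PR applied (SlowToySet and the chosen sizes are hypothetical, purely for illustration):

import time

import torch
from torch.utils.data import DataLoader, Dataset

from mmengine.runner.loops import _InfiniteDataloaderIterator


class SlowToySet(Dataset):
    """Toy dataset that simulates slow per-sample I/O."""

    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        time.sleep(0.005)  # simulate disk / decoding latency
        return torch.tensor(idx)


def make_iterator():
    return _InfiniteDataloaderIterator(
        DataLoader(SlowToySet(), batch_size=2, num_workers=0))


resumed_iter = 500

# Old behaviour: every skipped step loads and collates a real batch.
it = make_iterator()
t0 = time.time()
for _ in range(resumed_iter):
    next(it)
print(f'old fast-forward: {time.time() - t0:.2f}s')

# New behaviour (this PR): only sampler indices are drawn, no data is loaded.
it = make_iterator()
t0 = time.time()
it.skip_iter(resumed_iter)
print(f'new fast-forward: {time.time() - t0:.2f}s')

In real training the saving scales with how far into the schedule the resume happens and with how expensive __getitem__ and collation are for the dataset in question.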