Module recsys.model_selection
Expand source code Browse git
from collections import defaultdict
from typing import Generator, List
import numpy as np
from recsys.types import BatchType
def round_robin_kfold(batches: List[BatchType],
                      n_splits: int = 10) -> Generator:
    """
    Round robin k-fold cross validation.

    Batch index ``i`` is assigned to fold ``i % n_splits``. Useful when
    batches are sorted by sequence length, to keep the training set and
    validation set as balanced as possible in sequence length.
    Train indices are shuffled to try to reduce the bias in the gradient
    updates; validation indices are returned in ascending order.

    Args:
        batches: The batches to split. Only ``len(batches)`` is used; the
            batch contents are never inspected.
        n_splits: Number of folds to generate. Must be at least 2.

    Yields:
        ``(train_index, valid_index)`` pairs of integer NumPy arrays, one
        pair per fold. When there are fewer batches than folds, the extra
        folds have an empty ``valid_index``.

    Raises:
        ValueError: If ``n_splits`` is smaller than 2 or there are fewer
            than 2 batches (no train/validation split is possible). As this
            is a generator, the error surfaces on the first iteration.
    """
    n = len(batches)
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")
    if n < 2:
        raise ValueError(
            f"need at least 2 batches to build a train/validation split, "
            f"got {n}")

    # NOTE: this reseeds NumPy's *global* random state, so the shuffles are
    # reproducible from one call to the next.
    np.random.seed(42)

    # np.arange always yields an integer array, even when a fold is empty
    # (fewer batches than folds). Building the folds from Python lists would
    # turn empty folds into float64 arrays and, through concatenation, make
    # the train indices float as well, which cannot be used for indexing.
    folds = [np.arange(fold_id, n, n_splits) for fold_id in range(n_splits)]

    for i in range(n_splits):
        # Every fold except the i-th one, concatenated in fold order.
        train_index = np.concatenate(folds[:i] + folds[i + 1:])
        np.random.shuffle(train_index)
        # Hand out a copy so that a caller mutating the validation indices
        # cannot corrupt the training indices of later folds.
        yield train_index, folds[i].copy()
Functions
def round_robin_kfold(batches: List[Dict[str, torch.Tensor]], n_splits: int = 10) -> Generator
Round robin k-fold cross validation. Useful when batches are sorted by sequence length, to keep the training set and validation set as balanced as possible in sequence length. Train indices are shuffled to try to reduce the bias in the gradient updates.
Expand source code Browse git
def round_robin_kfold(batches: List[BatchType],
                      n_splits: int = 10) -> Generator:
    """
    Round robin k-fold cross validation.

    Batch index ``i`` is assigned to fold ``i % n_splits``. Useful when
    batches are sorted by sequence length, to keep the training set and
    validation set as balanced as possible in sequence length.
    Train indices are shuffled to try to reduce the bias in the gradient
    updates; validation indices are returned in ascending order.

    Args:
        batches: The batches to split. Only ``len(batches)`` is used; the
            batch contents are never inspected.
        n_splits: Number of folds to generate. Must be at least 2.

    Yields:
        ``(train_index, valid_index)`` pairs of integer NumPy arrays, one
        pair per fold. When there are fewer batches than folds, the extra
        folds have an empty ``valid_index``.

    Raises:
        ValueError: If ``n_splits`` is smaller than 2 or there are fewer
            than 2 batches (no train/validation split is possible). As this
            is a generator, the error surfaces on the first iteration.
    """
    n = len(batches)
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")
    if n < 2:
        raise ValueError(
            f"need at least 2 batches to build a train/validation split, "
            f"got {n}")

    # NOTE: this reseeds NumPy's *global* random state, so the shuffles are
    # reproducible from one call to the next.
    np.random.seed(42)

    # np.arange always yields an integer array, even when a fold is empty
    # (fewer batches than folds). Building the folds from Python lists would
    # turn empty folds into float64 arrays and, through concatenation, make
    # the train indices float as well, which cannot be used for indexing.
    folds = [np.arange(fold_id, n, n_splits) for fold_id in range(n_splits)]

    for i in range(n_splits):
        # Every fold except the i-th one, concatenated in fold order.
        train_index = np.concatenate(folds[:i] + folds[i + 1:])
        np.random.shuffle(train_index)
        # Hand out a copy so that a caller mutating the validation indices
        # cannot corrupt the training indices of later folds.
        yield train_index, folds[i].copy()