Module recsys.model_train
Expand source code Browse git
import logging
import os
import sys
import time
from datetime import timedelta
from typing import Tuple, Dict, List
import pandas as pd
import torch
from recsys.model_selection import round_robin_kfold
from recsys import config
if 'ipykernel' in sys.modules:
from tqdm.notebook import tqdm
else:
from tqdm import tqdm
from recsys.config import DEVICE
from recsys.model import BookingNet
from recsys.paths import get_path, get_model_arch_path
from recsys.types import BatchType, OptimizerType
def get_model_metrics(batch: "BatchType",
                      seq_len: torch.Tensor,
                      city_scores: torch.Tensor,
                      k: int = 4) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get model metrics for a batch: hit counts at 1 and at `k` (accuracy@1, accuracy@4 by default).

    Predictions are read at the last valid time step of each sequence (index ``seq_len - 1``).
    Predicted class indices are shifted by ``+1`` before being compared with
    ``batch['last_city']``, i.e. class ``i`` is compared against city id ``i + 1``.

    Args:
        batch: dict of tensors; ``batch['city_id']`` must be 2-D ``(batch_size, time_steps)``
            and ``batch['last_city']`` holds one target city per sequence.
        seq_len: number of valid time steps of each sequence, one entry per sequence.
        city_scores: scores with one row per (sequence, time step) pair and one column per class.
        k: size of the top-k candidate list used for ``hits_at_k`` (default 4).

    Returns:
        ``(hits_at_1, hits_at_k)`` as scalar float tensors holding hit *counts* (not rates);
        divide by the number of sequences to obtain accuracies.
    """
    bs, ts = batch['city_id'].shape
    # Index of the last valid step of each sequence. It is cast to int64 (required by
    # `gather`) and placed on the same device as the scores so `gather` never mixes devices.
    last_step = (seq_len.to(device=city_scores.device, dtype=torch.long) - 1).unsqueeze(1)
    # Top-1: per-row argmax -> (bs, ts) -> prediction at the last valid step -> (bs,)
    predicted_cities = (city_scores
                        .argmax(1)
                        .view(-1, ts)
                        .gather(1, last_step)
                        .squeeze(1) + 1)
    # Top-k: per-row top-k indices -> (bs, ts, k) -> candidates at the last valid step -> (bs, k)
    predicted_cities_top_k = (torch.topk(city_scores, k, dim=1).indices
                              .view(bs, -1, k)
                              .gather(1, last_step.unsqueeze(2).expand(-1, 1, k))
                              .squeeze(1) + 1)
    last_city = batch['last_city']
    hits_at_1 = (predicted_cities == last_city).float().sum()
    # topk indices are distinct, so each row matches at most once.
    hits_at_k = predicted_cities_top_k.eq(last_city.unsqueeze(1)).any(dim=1).float().sum()
    return hits_at_1, hits_at_k
def train_step(model: BookingNet,
               batch: Tuple[BatchType, torch.Tensor],
               n_cities: int = 39901) -> Dict:
    """
    Training step: forward pass, loss evaluation, backprop and one optimizer update.

    Args:
        model: network exposing an ``optimizer`` attribute and a ``get_loss`` method.
        batch: ``(features, seq_len)`` pair, i.e. the dict of input tensors and the
            per-sequence lengths, as unpacked below.
        n_cities: number of output classes; the model output is reshaped to
            ``(-1, n_cities)`` before the loss is computed (default 39901, the value
            previously hard-coded here).

    Returns:
        ``{'train_loss': float}`` with the loss of this batch.
    """
    features, seq_len = batch
    model.optimizer.zero_grad(set_to_none=True)
    city_scores = model(features, seq_len).view(-1, n_cities)
    loss = model.get_loss(city_scores,
                          features,
                          seq_len,
                          device=DEVICE)
    loss.backward()
    model.optimizer.step()
    return {'train_loss': loss.item()}
def validation_step(model: BookingNet,
                    batch: Tuple[BatchType, torch.Tensor],
                    n_cities: int = 39901) -> Dict:
    """
    Validation step: forward pass, loss and hit counts for one batch (no parameter update).

    Args:
        model: network exposing a ``get_loss`` method.
        batch: ``(features, seq_len)`` pair, i.e. the dict of input tensors and the
            per-sequence lengths, as unpacked below.
        n_cities: number of output classes; the model output is reshaped to
            ``(-1, n_cities)`` (default 39901, the value previously hard-coded here).

    Returns:
        dict with ``'valid_loss'`` (float), ``'hits_at_1'`` and ``'hits_at_k'``
        (hit counts as floats, from ``get_model_metrics``) and ``'obs'`` (length of
        the first dimension of ``features['city_id']``).
    """
    features, seq_len = batch
    city_scores = model(features, seq_len).view(-1, n_cities)
    # NOTE(review): unlike `train_step`, no `device=` argument is passed to `get_loss`
    # here; confirm that its default device matches the one used during training.
    loss = model.get_loss(city_scores,
                          features,
                          seq_len)
    hits_at_1, hits_at_k = get_model_metrics(features, seq_len, city_scores)
    return {
        'valid_loss': loss.item(),
        'hits_at_1': hits_at_1.item(),
        'hits_at_k': hits_at_k.item(),
        'obs': len(features['city_id']),
    }
def train_for_all_batches(model: BookingNet,
                          train_batches: List[Tuple[BatchType, torch.Tensor]]) -> Dict:
    """
    Train the model for one pass over all given batches.

    Args:
        model: network to train; it is switched to training mode first.
        train_batches: sequence of ``(features, seq_len)`` pairs, each passed to ``train_step``.

    Returns:
        dict with ``'train_loss'`` (mean loss per batch, NaN when ``train_batches``
        is empty) and ``'ellapsed_time'`` (wall-clock duration truncated to whole
        seconds, as a ``timedelta``).
    """
    start = time.time()
    model.train()
    total_loss = 0.0
    for batch in train_batches:
        total_loss += train_step(model, batch)['train_loss']
    n_batches = len(train_batches)
    # An empty fold yields NaN instead of raising ZeroDivisionError.
    train_loss = total_loss / n_batches if n_batches else float('nan')
    return {
        'train_loss': train_loss,
        'ellapsed_time': timedelta(seconds=int(time.time() - start)),
    }
def valid_for_all_batches(model: BookingNet,
                          valid_batches: List[Tuple[BatchType, torch.Tensor]]) -> Dict:
    """
    Run validation metrics over all batches, in eval mode and without gradient tracking.

    Args:
        model: network to evaluate; it is switched to eval mode first.
        valid_batches: sequence of ``(features, seq_len)`` pairs, each passed to
            ``validation_step``.

    Returns:
        dict with ``'valid_loss'`` (mean loss per batch), ``'accuracy@1'`` and
        ``'accuracy@4'`` (total hits divided by total observations) and
        ``'ellapsed_time_valid'`` (``timedelta``, whole seconds). A ratio is NaN when
        its denominator is zero, e.g. when ``valid_batches`` is empty.
    """
    start = time.time()
    totals = {'valid_loss': 0, 'hits_at_1': 0, 'hits_at_k': 0, 'obs': 0}
    model.eval()
    with torch.no_grad():
        for batch in valid_batches:
            batch_result = validation_step(model, batch)
            for key in totals:
                totals[key] += batch_result[key]
    elapsed = timedelta(seconds=int(time.time() - start))
    n_batches = len(valid_batches)
    n_obs = totals['obs']
    # Guard the divisions so an empty fold yields NaN instead of ZeroDivisionError.
    nan = float('nan')
    return {
        'valid_loss': totals['valid_loss'] / n_batches if n_batches else nan,  # loss per batch
        'accuracy@1': totals['hits_at_1'] / n_obs if n_obs else nan,
        'accuracy@4': totals['hits_at_k'] / n_obs if n_obs else nan,
        'ellapsed_time_valid': elapsed,
    }
def model_checkpoint_exists(model_hash: str,
                            fold: int) -> bool:
    """
    Tell whether the best-accuracy@k checkpoint file for `fold` is already on disk.

    Args:
        model_hash: identifier of the model; combined with "models" as the
            directory list passed to `get_path`.
        fold: fold number embedded in the checkpoint file name.

    Returns:
        True when the resolved ``.pt`` path exists, False otherwise.
    """
    checkpoint_name = f"fold_{fold}_best_accuracy_at_k"
    return os.path.exists(
        get_path(dirs=["models", model_hash], filename=checkpoint_name, format="pt")
    )
def train_model(model: BookingNet,
                train_batches: List[Tuple[BatchType, torch.Tensor]],
                valid_batches: List[Tuple[BatchType, torch.Tensor]],
                epochs: int = 50,
                fold: int = 0,
                min_epochs_to_save: int = 20,
                verbose: bool = True) -> pd.DataFrame:
    """
    Train model from batches and save checkpoints of the best model by validation accuracy@4.

    After every epoch the model is evaluated on ``valid_batches``. From the 0-based
    epoch ``min_epochs_to_save`` onward, the state dict is saved whenever accuracy@4
    beats the best value seen so far. When training ends, the per-epoch report is
    written to CSV and ``str(model)`` is written to the model architecture file.

    Args:
        model: network to train (must expose ``hash``).
        train_batches: ``(features, seq_len)`` pairs used for training.
        valid_batches: ``(features, seq_len)`` pairs used for validation.
        epochs: number of passes over ``train_batches``.
        fold: fold number, used in checkpoint and report paths.
        min_epochs_to_save: first (0-based) epoch at which checkpoints may be saved.
        verbose: if True, log a one-line summary per epoch.

    Returns:
        DataFrame indexed by epoch with the training and validation metrics.
    """
    if min_epochs_to_save >= epochs:
        # No epoch index in range(epochs) reaches the threshold, so nothing would be saved.
        logging.warning(f"min_epochs_to_save={min_epochs_to_save} >= epochs={epochs}: "
                        f"no checkpoint will be saved for model {model.hash}, fold {fold}")
    epoch_report = {}
    # Start below any attainable accuracy so the first eligible epoch always writes
    # a checkpoint, even when its accuracy@4 is 0.
    best_accuracy_at_k = float('-inf')
    for epoch in tqdm(range(epochs)):
        train_report = train_for_all_batches(model, train_batches)
        valid_report = valid_for_all_batches(model, valid_batches)
        if epoch >= min_epochs_to_save and valid_report['accuracy@4'] > best_accuracy_at_k:
            best_accuracy_at_k = valid_report['accuracy@4']
            torch.save(model.state_dict(), get_path(dirs=["models", model.hash],
                                                    filename=f"fold_{fold}_best_accuracy_at_k",
                                                    format="pt"))
        r = {**train_report, **valid_report}
        epoch_report[epoch] = r
        if verbose:
            epoch_str = ', '.join([f"Epoch: {epoch}",
                                   f"train loss: {r['train_loss']:.4f}",
                                   f"valid loss: {r['valid_loss']:.4f}",
                                   f"accuracy@1: {r['accuracy@1']:.4f}",
                                   f"accuracy@4: {r['accuracy@4']:.4f}",
                                   f"time: {r['ellapsed_time']}"])
            logging.info(epoch_str)
    # Build the report once; it is both saved and returned.
    report = pd.DataFrame(epoch_report).T
    report.to_csv(get_path(dirs=["reports", model.hash],
                           hash=model.hash,
                           fold=fold,
                           format='csv'))
    with open(get_model_arch_path(model.hash), "w", encoding="utf-8") as fhandle:
        fhandle.write(str(model))
    return report
def train_model_for_folds(dataset_batches: List[BatchType],
                          train_set: pd.DataFrame,
                          model_configuration: Dict,
                          n_models: int = config.N_SPLITS,
                          min_epochs_to_save: int = 25,
                          skip_checkpoint=False) -> str:
    """
    Train up to `n_models` fold models for one configuration, returning the model hash.

    Folds come from ``round_robin_kfold(dataset_batches, n_splits=config.N_SPLITS)``,
    and only the first ``n_models`` of them are used. For each fold, a fresh
    ``BookingNet(**model_configuration)`` is built on ``config.DEVICE`` with an AdamW
    optimizer and entropy weights from ``train_set``. It is trained for
    ``config.EPOCHS`` epochs, unless a checkpoint for that fold already exists and
    ``skip_checkpoint`` is False.

    Args:
        dataset_batches: batches indexed with the fold index arrays from
            ``round_robin_kfold`` (presumably a NumPy object array — verify against callers).
        train_set: data passed to ``model.set_entropy_weights``.
        model_configuration: keyword arguments for ``BookingNet``.
        n_models: maximum number of folds to process.
        min_epochs_to_save: forwarded to ``train_model``.
        skip_checkpoint: if True, retrain folds even when their checkpoint exists.

    Returns:
        The ``hash`` of the last model built.

    Raises:
        ValueError: if no fold was processed (e.g. ``n_models < 1``), so no hash exists.
    """
    model_hash = None
    for fold, (train_index, valid_index) in enumerate(round_robin_kfold(dataset_batches,
                                                                        n_splits=config.N_SPLITS)):
        if fold >= n_models:
            break
        model = BookingNet(**model_configuration).to(config.DEVICE)
        try:
            model.set_optimizer(optimizer_type=OptimizerType.ADAMW)
            model.set_entropy_weights(train_set)
            model_hash = model.hash
            if not skip_checkpoint and model_checkpoint_exists(model_hash, fold):
                continue
            train_batches = dataset_batches[train_index]
            valid_batches = dataset_batches[valid_index]
            logging.info(f"Training model {model_hash} for fold {fold}")
            train_model(model,
                        train_batches,
                        valid_batches,
                        epochs=config.EPOCHS,
                        min_epochs_to_save=min_epochs_to_save,
                        fold=fold)
        finally:
            # Release this fold's model (including on the skip path) before the next
            # fold allocates a new one, then empty the CUDA cache.
            del model
            torch.cuda.empty_cache()
    if model_hash is None:
        raise ValueError(f"No fold was processed (n_models={n_models}); "
                         f"cannot determine the model hash.")
    return model_hash
Functions
def get_model_metrics(batch: Dict[str, torch.Tensor], seq_len: torch.Tensor, city_scores: torch.Tensor) ‑> Tuple-
Get model metrics, e.g. accuracy@1, accuracy@4.
Expand source code Browse git
def get_model_metrics(batch: BatchType, seq_len: torch.Tensor, city_scores: torch.Tensor) -> Tuple: """ Get model metrics, e.g. accuracy@1, accuracy@4. """ bs, ts = batch['city_id'].shape predicted_cities = (city_scores .argmax(1) .view(-1, ts) .gather(1, (seq_len.unsqueeze(1) - 1).to(DEVICE)) .squeeze(1) + 1) predicted_cities_top_k = (torch.topk(city_scores, 4, dim=1).indices.view(bs, -1, 4) .gather(1, (torch.cat([seq_len.unsqueeze(1)] * 4, axis=1).view(-1, 1, 4) - 1) .to(DEVICE))).squeeze(1) + 1 hits_at_1 = (predicted_cities == batch['last_city']).float().sum() hits_at_k = torch.sum(predicted_cities_top_k.eq(batch['last_city'].unsqueeze(1)), dim=1).float().sum() return hits_at_1, hits_at_k def model_checkpoint_exists(model_hash: str, fold: int) ‑> bool-
Returns `true` if the model checkpoint given by the path exists, `false` otherwise.
Expand source code Browse git
def model_checkpoint_exists(model_hash: str, fold: int) -> bool: """ Returns `true` if the model checkpoint given by the path exists, `false` otherwise. """ ckpt_path = get_path(dirs=["models", model_hash], filename=f"fold_{fold}_best_accuracy_at_k", format="pt") return os.path.exists(ckpt_path) def train_for_all_batches(model: BookingNet, train_batches: List[Dict[str, torch.Tensor]]) ‑> Dict-
Train model on all given batches.
Expand source code Browse git
def train_for_all_batches(model: BookingNet, train_batches: List[BatchType]) -> Dict: """ Train model on all given batches. """ current_time = time.time() train_loss = 0 model.train() for batch in train_batches: train_step_result = train_step(model, batch) train_loss += train_step_result['train_loss'] train_loss /= len(train_batches) # loss per batch ellapsed_time = timedelta(seconds=int(time.time() - current_time)) return { 'train_loss': train_loss, 'ellapsed_time': ellapsed_time, } def train_model(model: BookingNet, train_batches: List[Dict[str, torch.Tensor]], valid_batches: List[Dict[str, torch.Tensor]], epochs: int = 50, fold: int = 0, min_epochs_to_save: int = 20, verbose: bool = True) ‑> pandas.core.frame.DataFrame-
Train model from batches and save checkpoints of best models by accuracy.
Expand source code Browse git
def train_model(model: BookingNet, train_batches: List[BatchType], valid_batches: List[BatchType], epochs: int = 50, fold: int = 0, min_epochs_to_save: int = 20, verbose: bool = True) -> pd.DataFrame: """ Train model from batches and save checkpoints of best models by accuracy. """ epoch_report = {} best_accuracy_at_k = 0 for epoch in tqdm(range(epochs)): train_report = train_for_all_batches(model, train_batches) valid_report = valid_for_all_batches(model, valid_batches) if epoch >= min_epochs_to_save and valid_report['accuracy@4'] > best_accuracy_at_k: best_accuracy_at_k = valid_report['accuracy@4'] torch.save(model.state_dict(), get_path(dirs=["models", model.hash], filename=f"fold_{fold}_best_accuracy_at_k", format="pt")) r = dict(train_report) r.update(valid_report) epoch_report[epoch] = r if verbose: epoch_str = [f"Epoch: {epoch}", f"train loss: {r['train_loss']:.4f}", f"valid loss: {r['valid_loss']:.4f}", f"accuracy@1: {r['accuracy@1']:.4f}", f"accuracy@4: {r['accuracy@4']:.4f}", f"time: {r['ellapsed_time']}"] epoch_str = ', '.join(epoch_str) logging.info(epoch_str) # save report pd.DataFrame(epoch_report).T.to_csv(get_path(dirs=["reports", model.hash], hash=model.hash, fold=fold, format='csv')) with open(get_model_arch_path(model.hash), "w") as fhandle: fhandle.write(str(model)) return pd.DataFrame(epoch_report).T def train_model_for_folds(dataset_batches: List[Dict[str, torch.Tensor]], train_set: pandas.core.frame.DataFrame, model_configuration: Dict, n_models: int = 10, min_epochs_to_save: int = 25, skip_checkpoint=False) ‑> str-
Train `n_models` given a model configuration, returning the model hash.
Expand source code Browse git
def train_model_for_folds(dataset_batches: List[BatchType], train_set: pd.DataFrame, model_configuration: Dict, n_models: int = config.N_SPLITS, min_epochs_to_save: int = 25, skip_checkpoint=False) -> str: """ Train `n_models` given a model configuration, returning the model hash. """ for fold, (train_index, valid_index) in enumerate(round_robin_kfold(dataset_batches, n_splits=config.N_SPLITS)): if fold >= n_models: break model = BookingNet(**model_configuration).to(config.DEVICE) model.set_optimizer(optimizer_type=OptimizerType.ADAMW) model.set_entropy_weights(train_set) model_hash = model.hash if not skip_checkpoint and model_checkpoint_exists(model.hash, fold): continue train_batches = dataset_batches[train_index] valid_batches = dataset_batches[valid_index] # valid_batches = filter_batches_by_length(valid_batches) logging.info(f"Training model {model.hash} for fold {fold}") train_model(model, train_batches, valid_batches, epochs=config.EPOCHS, min_epochs_to_save=min_epochs_to_save, fold=fold) # Empty CUDA memory del model torch.cuda.empty_cache() return model_hash def train_step(model: BookingNet, batch: List[Dict[str, torch.Tensor]]) ‑> Dict-
Training step, including loss evaluation and backprop.
Expand source code Browse git
def train_step(model: BookingNet, batch: List[BatchType]) -> Dict: """ Training step, including loss evaluation and backprop. """ batch, seq_len = batch model.optimizer.zero_grad(set_to_none=True) city_scores = model(batch, seq_len) city_scores = city_scores.view(-1, 39901) loss = model.get_loss(city_scores, batch, seq_len, device=DEVICE) loss.backward() model.optimizer.step() return { 'train_loss': loss.item() } def valid_for_all_batches(model: BookingNet, valid_batches: List[Dict[str, torch.Tensor]]) ‑> Dict-
Run validation set metrics for all batches.
Expand source code Browse git
def valid_for_all_batches(model: BookingNet, valid_batches: List[BatchType]) -> Dict: """ Run validation set metrics for all batches. """ current_time = time.time() valid_result = { 'valid_loss': 0, 'hits_at_1': 0, 'hits_at_k': 0, 'obs': 0 } model.eval() with torch.no_grad(): for batch in valid_batches: batch_result = validation_step(model, batch) for key in valid_result.keys(): valid_result[key] += batch_result[key] ellapsed_time = timedelta(seconds=int(time.time() - current_time)) return { 'valid_loss': valid_result['valid_loss'] / len(valid_batches), # loss per batch 'accuracy@1': valid_result['hits_at_1'] / valid_result['obs'], 'accuracy@4': valid_result['hits_at_k'] / valid_result['obs'], 'ellapsed_time_valid': ellapsed_time } def validation_step(model: BookingNet, batch: Dict[str, torch.Tensor]) ‑> Dict-
Validation step, including metric computation for batch.
Expand source code Browse git
def validation_step(model: BookingNet, batch: BatchType) -> Dict: """ Validation step, including metric computation for batch. """ batch, seq_len = batch city_scores = model(batch, seq_len) city_scores = city_scores.view(-1, 39901) loss = model.get_loss(city_scores, batch, seq_len) hits_at_1, hits_at_k = get_model_metrics(batch, seq_len, city_scores) obs = len(batch['city_id']) return { 'valid_loss': loss.item(), 'hits_at_1': hits_at_1.item(), 'hits_at_k': hits_at_k.item(), 'obs': obs }