def test(loader, model, criterion, device, CONFIG):
    test_timer = Timer()
    metrics = [AverageMeter("XELoss"), AverageMeter("Accuracy (%)")]
    global_metrics = [AverageMeter("XELoss"), AverageMeter("Accuracy (%)")]
    model.eval()
    for it, data in enumerate(loader):
        clip = data["clip"].to(device)
        label = data["label"].to(device)
        # print GPU utilization once, after the first iteration has warmed up
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])
        with torch.no_grad():
            out = model(clip)
            loss, lossdict = criterion(out, label)
        # `metrics` is reset every 10 iterations for logging;
        # `global_metrics` accumulates over the whole test set
        for metric in metrics:
            metric.update(lossdict[metric.name])
        for metric in global_metrics:
            metric.update(lossdict[metric.name])
        if it % 10 == 9:
            metricstr = " | ".join([f"test {metric}" for metric in metrics])
            print(
                f"test | {test_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            for metric in metrics:
                metric.reset()
    # report and return the global accuracy
    metric = global_metrics[-1]
    if CONFIG.use_wandb:
        wandb.log({f"test {metric.name}": metric.avg}, commit=False)
    return metric.avg
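# ---------------------------------------------------------------------------
# NOTE: `AverageMeter` and `Timer` are helpers defined elsewhere in this repo.
# A minimal sketch of the interface the functions in this file rely on
# (`name`, `update()`, `reset()`, `avg`, and string formatting) -- an
# assumption for illustration, not the repository's actual implementation:
#
#     import time
#
#     class AverageMeter:
#         """Running average of a named scalar metric."""
#         def __init__(self, name):
#             self.name = name
#             self.reset()
#
#         def reset(self):
#             self.sum, self.count, self.avg = 0.0, 0, 0.0
#
#         def update(self, val, n=1):
#             self.sum += val * n
#             self.count += n
#             self.avg = self.sum / self.count
#
#         def __str__(self):
#             return f"{self.name}: {self.avg:.4f}"
#
#     class Timer:
#         """Elapsed wall-clock time since construction, as hh:mm:ss."""
#         def __init__(self):
#             self.start = time.time()
#
#         def __str__(self):
#             e = int(time.time() - self.start)
#             return f"{e // 3600:02d}:{e % 3600 // 60:02d}:{e % 60:02d}"
# ---------------------------------------------------------------------------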
def validate(loader, model, criterion, device, CONFIG, epoch):
    val_timer = Timer()
    metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)"),
    ]
    global_metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)"),
    ]
    # DPC has no MSE term. Note the trailing comma: `in ("DPC")` tests
    # substring membership in a plain string, not tuple membership.
    if CONFIG.model in ("DPC",):
        metrics.pop(1)
        global_metrics.pop(1)
    model.eval()
    for it, data in enumerate(loader):
        clip = data["clip"].to(device)
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])
        with torch.no_grad():
            output = model(clip)
            loss, lossdict = criterion(*output)
        for metric in metrics:
            metric.update(lossdict[metric.name])
        for metric in global_metrics:
            metric.update(lossdict[metric.name])
        if it % 10 == 9:
            metricstr = " | ".join(
                [f"validation {metric}" for metric in metrics])
            print(
                f"epoch {epoch:03d}/{CONFIG.max_epoch:03d} | valid | "
                f"{val_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            for metric in metrics:
                metric.reset()
        # validating for 100 steps is enough
        if it == 100:
            break
    if CONFIG.use_wandb:
        for metric in global_metrics:
            wandb.log({f"epoch {metric.name}": metric.avg}, commit=False)
    return global_metrics[-1].avg
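# ---------------------------------------------------------------------------
# NOTE on the criterion contract: `validate`/`train_epoch` call
# `loss, lossdict = criterion(*output)` and `test` calls
# `criterion(out, label)`, so every criterion is assumed to return a scalar
# loss plus a dict keyed by the metric names above. A hypothetical
# classification criterion obeying that contract (illustration only):
#
#     class XECriterion(torch.nn.Module):
#         def forward(self, logits, target):
#             loss = torch.nn.functional.cross_entropy(logits, target)
#             acc = (logits.argmax(1) == target).float().mean() * 100
#             return loss, {"XELoss": loss.item(), "Accuracy (%)": acc.item()}
# ---------------------------------------------------------------------------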
def train_epoch(loader, model, optimizer, criterion, device, CONFIG, epoch):
    train_timer = Timer()
    metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)"),
    ]
    # DPC has no MSE term; `("DPC",)` must be a tuple for a membership test
    if CONFIG.model in ("DPC",):
        metrics.pop(1)
    model.train()
    for it, data in enumerate(loader):
        clip = data["clip"].to(device)
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])
        optimizer.zero_grad()
        output = model(clip)
        loss, lossdict = criterion(*output)
        for metric in metrics:
            metric.update(lossdict[metric.name])
        loss.backward()
        # clip the global gradient norm to stabilize training
        if CONFIG.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=CONFIG.grad_clip)
        optimizer.step()
        if it % 10 == 9:
            metricstr = " | ".join([f"train {metric}" for metric in metrics])
            print(
                f"epoch {epoch:03d}/{CONFIG.max_epoch:03d} | train | "
                f"{train_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            if CONFIG.use_wandb:
                for metric in metrics:
                    wandb.log({f"train {metric.name}": metric.avg},
                              commit=False)
                wandb.log({"iteration": it + (epoch - 1) * len(loader)})
            for metric in metrics:
                metric.reset()
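# ---------------------------------------------------------------------------
# Sketch of how `train_epoch` and `validate` would typically be wired
# together. `train_loader`, `val_loader`, `model`, `optimizer`, and
# `criterion` are assumed to be constructed elsewhere; the best-checkpoint
# logic is illustrative, not this repository's actual driver code:
#
#     best_acc = 0.0
#     for epoch in range(1, CONFIG.max_epoch + 1):
#         train_epoch(train_loader, model, optimizer, criterion,
#                     device, CONFIG, epoch)
#         acc = validate(val_loader, model, criterion, device, CONFIG, epoch)
#         if acc > best_acc:
#             best_acc = acc
#             torch.save(model.state_dict(), "checkpoint_best.pth")
# ---------------------------------------------------------------------------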
def extract_features(dataset, model, device, CONFIG):
    test_timer = Timer()
    model.eval()
    for it, data in enumerate(dataset):
        # clip: (T/n_clip, n_clip, C, clip_len, H, W)
        clip = data["clip"].to(device)
        clip_id = data["id"]  # renamed from `id` to avoid shadowing the builtin
        duration = data["duration"]
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])
        with torch.no_grad():
            out = model(clip, flag="extract")
            # (T, 7 * 7, D): pool over the 7x7 spatial grid, trim padded frames
            out = out.reshape(-1, 7 * 7, CONFIG.hidden_size)
            out = out.mean(1)[:duration]
        print(out.size())
        torch.save(out, os.path.join(dataset.root_path, f"feature/{clip_id}.pth"))
        if it % 10 == 9:
            print(
                f"extracting features | {test_timer} | "
                f"iter {it+1:06d}/{len(dataset):06d}",
                flush=True,
            )
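# ---------------------------------------------------------------------------
# Each saved file holds a (duration, hidden_size) tensor, so downstream code
# can read features back with `torch.load`. `some_video_id` below is a
# placeholder, not a real id:
#
#     feat = torch.load(os.path.join(dataset.root_path,
#                                    "feature/some_video_id.pth"))
#     assert feat.dim() == 2 and feat.size(1) == CONFIG.hidden_size
# ---------------------------------------------------------------------------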
if __name__ == "__main__":
    global_timer = Timer()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="cfg/default.yml",
        help="path to configuration yml file",
    )
    opt = parser.parse_args()

    print(f"loading configuration from {opt.config}")
    # wrap the parsed YAML in Dict for attribute-style access (CONFIG.model, ...)
    with open(opt.config) as f:
        CONFIG = Dict(yaml.safe_load(f))
    print("CONFIGURATIONS:")
    pprint(CONFIG)

    CONFIG.new_config_name = f"ft_{CONFIG.dataset}_{CONFIG.config_name}"
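    # -----------------------------------------------------------------------
    # Sketch of the keys `cfg/default.yml` would need, inferred from the
    # CONFIG attributes this script reads; the values are illustrative only:
    #
    #     config_name: default
    #     dataset: kinetics
    #     model: DPC
    #     max_epoch: 100
    #     grad_clip: 5.0
    #     hidden_size: 256
    #     use_wandb: false
    # -----------------------------------------------------------------------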