Code example #1
File: test.py | Project: skasai5296/dpc
def test(loader, model, criterion, device, CONFIG):
    test_timer = Timer()
    metrics = [AverageMeter("XELoss"), AverageMeter("Accuracy (%)")]
    global_metrics = [AverageMeter("XELoss"), AverageMeter("Accuracy (%)")]
    model.eval()
    for it, data in enumerate(loader):
        clip = data["clip"].to(device)
        label = data["label"].to(device)
        if it == 1 and torch.cuda.is_available():
            # print GPU utilization once, shortly after the loop starts
            subprocess.run(["nvidia-smi"])

        with torch.no_grad():
            out = model(clip)
            loss, lossdict = criterion(out, label)

        for metric in metrics:
            metric.update(lossdict[metric.name])
        for metric in global_metrics:
            metric.update(lossdict[metric.name])
        if it % 10 == 9:
            metricstr = " | ".join([f"test {metric}" for metric in metrics])
            print(
                f"test | {test_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            for metric in metrics:
                metric.reset()
    metric = global_metrics[-1]
    if CONFIG.use_wandb:
        wandb.log({f"test {metric.name}": metric.avg}, commit=False)
    return metric.avg
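These snippets lean on Timer and AverageMeter helpers defined elsewhere in the repository. Below is a minimal sketch of what such helpers could look like, assuming Timer prints elapsed wall-clock time and AverageMeter keeps a running mean under a display name; the real implementations in skasai5296/dpc may differ.

import time


class Timer:
    """Tracks wall-clock time since construction."""

    def __init__(self):
        self.start = time.time()

    def __str__(self):
        return f"elapsed {time.time() - self.start:8.1f}s"


class AverageMeter:
    """Running average of a scalar, printed as 'name: value'."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += float(value) * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def __str__(self):
        return f"{self.name}: {self.avg:.4f}"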
Code example #2
def validate(loader, model, criterion, device, CONFIG, epoch):
    val_timer = Timer()
    metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)")
    ]
    global_metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)")
    ]
    if CONFIG.model in ("DPC",):
        # drop the MSELoss meter (index 1), which DPC does not report
        metrics.pop(1)
        global_metrics.pop(1)
    model.eval()
    for it, data in enumerate(loader):

        clip = data["clip"].to(device)
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])

        with torch.no_grad():
            output = model(clip)
            loss, lossdict = criterion(*output)

        for metric in metrics:
            metric.update(lossdict[metric.name])
        for metric in global_metrics:
            metric.update(lossdict[metric.name])
        if it % 10 == 9:
            metricstr = " | ".join(
                [f"validation {metric}" for metric in metrics])
            print(
                f"epoch {epoch:03d}/{CONFIG.max_epoch:03d} | valid | "
                f"{val_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            for metric in metrics:
                metric.reset()
        # ~100 iterations of validation are enough; stop early
        if it == 100:
            break
    if CONFIG.use_wandb:
        for metric in global_metrics:
            wandb.log({f"epoch {metric.name}": metric.avg}, commit=False)
    return global_metrics[-1].avg
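Both test and validate assume that criterion returns a (loss, lossdict) pair in which lossdict is keyed by the same display names passed to the AverageMeter instances ("XELoss", "MSELoss", "Accuracy (%)"). A hypothetical criterion in that style is sketched below for the plain cross-entropy case of test(); the actual criteria used in skasai5296/dpc may compute these values differently.

import torch
import torch.nn as nn


class CrossEntropyWithMetrics(nn.Module):
    """Returns the loss used for backprop plus a dict of display name -> scalar."""

    def __init__(self):
        super().__init__()
        self.xent = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        loss = self.xent(logits, target)
        accuracy = (logits.argmax(dim=1) == target).float().mean() * 100
        lossdict = {"XELoss": loss.item(), "Accuracy (%)": accuracy.item()}
        return loss, lossdict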
Code example #3
def train_epoch(loader, model, optimizer, criterion, device, CONFIG, epoch):
    train_timer = Timer()
    metrics = [
        AverageMeter("XELoss"),
        AverageMeter("MSELoss"),
        AverageMeter("Accuracy (%)")
    ]
    if CONFIG.model in ("DPC",):
        # drop the MSELoss meter (index 1), which DPC does not report
        metrics.pop(1)
    model.train()
    for it, data in enumerate(loader):

        clip = data["clip"].to(device)
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])

        optimizer.zero_grad()
        output = model(clip)
        loss, lossdict = criterion(*output)

        for metric in metrics:
            metric.update(lossdict[metric.name])
        loss.backward()
        if CONFIG.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=CONFIG.grad_clip)
        optimizer.step()
        if it % 10 == 9:
            metricstr = " | ".join([f"train {metric}" for metric in metrics])
            print(
                f"epoch {epoch:03d}/{CONFIG.max_epoch:03d} | train | "
                f"{train_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            if CONFIG.use_wandb:
                for metric in metrics:
                    wandb.log({f"train {metric.name}": metric.avg},
                              commit=False)
                wandb.log({"iteration": it + (epoch - 1) * len(loader)})
            for metric in metrics:
                metric.reset()
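train_epoch logs several meters with commit=False and then a bare wandb.log call for the iteration counter; with Weights & Biases, commit=False accumulates key-value pairs into the current step, and the next call without commit=False commits them all at once. A small standalone sketch of that pattern (project name and metric values are illustrative; offline mode avoids needing a W&B account):

import wandb

wandb.init(project="dpc-example", mode="offline")  # hypothetical project name
for step in range(3):
    wandb.log({"train XELoss": 1.0 / (step + 1)}, commit=False)
    wandb.log({"train Accuracy (%)": 10.0 * step}, commit=False)
    # the call without commit=False flushes everything above as a single step
    wandb.log({"iteration": step})
wandb.finish()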
Code example #4
File: extract.py | Project: skasai5296/dpc
def extract_features(dataset, model, device, CONFIG):
    test_timer = Timer()
    model.eval()
    for it, data in enumerate(dataset):
        clip = data["clip"].to(device)
        video_id = data["id"]
        if it == 1 and torch.cuda.is_available():
            subprocess.run(["nvidia-smi"])

        # (T/n_clip, n_clip, C, clip_len, H, W)
        duration = data["duration"]
        with torch.no_grad():
            out = model(clip, flag="extract")
        # (T, 7 * 7, D)
        out = out.reshape(-1, 7 * 7, CONFIG.hidden_size)
        out = out.mean(1)[:duration]
        print(out.size())
        torch.save(out, os.path.join(dataset.root_path, f"feature/{video_id}.pth"))

        if it % 10 == 9:
            print(
                f"extracting features | {test_timer} | iter {it+1:06d}/{len(dataset):06d} | ",
                flush=True,
            )
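extract_features writes one tensor per video under feature/{id}.pth, shaped (duration, CONFIG.hidden_size) after averaging over the 7 x 7 spatial grid. A sketch of reading such a file back, with a placeholder root path and id (the real values come from dataset.root_path and data["id"]):

import os

import torch

root_path = "/path/to/dataset"  # assumed value of dataset.root_path
video_id = "some_video_id"      # placeholder for data["id"]
feat = torch.load(os.path.join(root_path, f"feature/{video_id}.pth"))
print(feat.shape)  # expected: (duration, hidden_size)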
Code example #5
File: test.py | Project: skasai5296/dpc
            metricstr = " | ".join([f"test {metric}" for metric in metrics])
            print(
                f"test | {test_timer} | iter {it+1:06d}/{len(loader):06d} | "
                f"{metricstr}",
                flush=True,
            )
            for metric in metrics:
                metric.reset()
    metric = global_metrics[-1]
    if CONFIG.use_wandb:
        wandb.log({f"test {metric.name}": metric.avg}, commit=False)
    return metric.avg


if __name__ == "__main__":
    global_timer = Timer()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="cfg/default.yml",
        help="path to configuration yml file",
    )
    opt = parser.parse_args()
    print(f"loading configuration from {opt.config}")
    with open(opt.config) as f:
        CONFIG = Dict(yaml.safe_load(f))
    print("CONFIGURATIONS:")
    pprint(CONFIG)

    CONFIG.new_config_name = f"ft_{CONFIG.dataset}_{CONFIG.config_name}"
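The __main__ block reads a YAML file into CONFIG with attribute access (Dict here is presumably addict.Dict). The snippets above reference config_name, dataset, model, max_epoch, grad_clip, hidden_size, and use_wandb, so a configuration along the following lines would satisfy them; the values are placeholders, not the project's actual defaults.

import yaml
from pprint import pprint

from addict import Dict  # assumed source of Dict

example_yaml = """
config_name: default
dataset: kinetics
model: DPC
max_epoch: 100
grad_clip: 5.0
hidden_size: 256
use_wandb: false
"""

CONFIG = Dict(yaml.safe_load(example_yaml))
pprint(CONFIG)
print(f"ft_{CONFIG.dataset}_{CONFIG.config_name}")  # mirrors new_config_name above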