# train.py (forked from jphdotam/A4C3D)
import argparse

import torch.distributed
import wandb
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from lib.vis import vis_mse
from lib.models import load_model
from lib.config import load_config
from lib.dataset import E32Dataset
from lib.losses import load_criterion
from lib.optimizers import load_optimizer
from lib.transforms import load_transforms
from lib.training import cycle, save_state

CONFIG = "/home/james/a4c3d/experiments/008.yaml"
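# NB: CONFIG is a hard-coded absolute path to the experiment YAML loaded in main();
# point it at a different file under experiments/ to train another configuration.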


def main():
    cfg = load_config(CONFIG)

    # distributed settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=4)
    args = parser.parse_args()

    if cfg['training']['data_parallel'] == 'distributed':
        distributed = True
        local_rank = args.local_rank
        torch.cuda.set_device(local_rank)
        world_size = args.ngpu
        torch.distributed.init_process_group('nccl', init_method="tcp://localhost:16534", world_size=world_size, rank=local_rank)
    else:
        distributed = False
        local_rank = None
        world_size = None
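    # NB: in 'distributed' mode each GPU runs its own copy of this script, and --local_rank is
    # assumed to be injected per process by the launcher (e.g. torch.distributed.launch), e.g.:
    #   python -m torch.distributed.launch --nproc_per_node=4 train.py --ngpu 4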
    # settings
    bs_train, bs_test, n_workers = cfg['training']['batch_size_train'], cfg['training']['batch_size_test'], cfg['training']['n_workers']
    n_epochs = cfg['training']['n_epochs']
    transforms_train, transforms_test = load_transforms(cfg)

    # data
    ds_train = E32Dataset(cfg, cfg['paths']['data_train'], 'train', transforms=transforms_train)
    ds_test = E32Dataset(cfg, cfg['paths']['data_test'], 'test', transforms=transforms_test)
    sampler_train = DistributedSampler(ds_train, num_replicas=world_size, rank=local_rank) if distributed else None
    sampler_test = DistributedSampler(ds_test, num_replicas=world_size, rank=local_rank) if distributed else None
    dl_train = DataLoader(ds_train, bs_train, shuffle=False if distributed else True, num_workers=n_workers, pin_memory=False, sampler=sampler_train)
    dl_test = DataLoader(ds_test, bs_test, shuffle=False, num_workers=n_workers, pin_memory=False, sampler=sampler_test)
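    # NB: shuffling is left to the DistributedSampler when training is distributed
    # (DataLoader requires shuffle=False whenever an explicit sampler is supplied).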
    # model
    model, starting_epoch, state = load_model(cfg, local_rank)
    optimizer, scheduler = load_optimizer(model, cfg, state, steps_per_epoch=len(dl_train))
    train_criterion, test_criterion = load_criterion(cfg)

    # WandB
    if not local_rank:
        wandb.init(project="a4c3d", config=cfg, notes=cfg.get("description", None))
        wandb.save("*.mp4")  # Write MP4 files immediately to WandB
        wandb.watch(model)
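    # NB: `not local_rank` is True both for rank 0 and for non-distributed runs
    # (local_rank is None), so only a single process logs to WandB.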
    # training
    best_loss, best_path, last_save_path = 1e10, None, None
    for epoch in range(starting_epoch, n_epochs + 1):
        if local_rank == 0:
            print(f"\nEpoch {epoch} of {n_epochs}")

        # Cycle
        train_loss = cycle('train', model, dl_train, epoch, train_criterion, optimizer, cfg, scheduler, local_rank=local_rank)
        test_loss = cycle('test', model, dl_test, epoch, test_criterion, optimizer, cfg, scheduler, local_rank=local_rank)

        # Save state if required
        if local_rank == 0:
            model_weights = model.module.state_dict() if cfg['training']['data_parallel'] else model.state_dict()
            state = {'epoch': epoch + 1,
                     'model': model_weights,
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler}
            save_name = f"{epoch}_{test_loss:.05f}.pt"
            best_loss, last_save_path = save_state(state, save_name, test_loss, best_loss, cfg, last_save_path, lowest_best=True)

            # Vis seg
            vis_mse(ds_test, model, epoch, cfg)

    if local_rank == 0:
        save_name = f"FINAL_{epoch}_{test_loss:.05f}.pt"
        save_state(state, save_name, test_loss, best_loss, cfg, last_save_path, force=True)


if __name__ == '__main__':
    main()