# For visualizations: draw one shuffled batch of 12 validation samples.
vis_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=12,
                                         shuffle=True,
                                         collate_fn=data.collate_remove_none,
                                         worker_init_fn=data.worker_init_fn)
data_vis = next(iter(vis_loader))

# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Initialize training
npoints = 1000  # NOTE(review): not referenced in this fragment; confirm it is used later.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

# Resume from 'model.pt' if it can be loaded, otherwise start fresh.
# NOTE(review): CheckpointIO.load is expected to signal a missing file with
# FileExistsError (not FileNotFoundError); confirm against CheckpointIO.
checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
try:
    load_dict = checkpoint_io.load('model.pt')
except FileExistsError:
    load_dict = dict()
# With no checkpoint, the counters default to -1 and the best metric defaults
# to -model_selection_sign * inf.
epoch_it = load_dict.get('epoch_it', -1)
it = load_dict.get('it', -1)
metric_val_best = load_dict.get('loss_val_best',
                                -model_selection_sign * np.inf)

# A previous bug could leave an infinite best metric of the wrong sign in
# saved checkpoints. Reset any infinity to the same default used when no
# checkpoint exists (finite values are kept as loaded).
if np.isinf(metric_val_best):
    metric_val_best = -model_selection_sign * np.inf
# Example #2
# Training setup (TensorFlow): optimizer, checkpoint restore, and config
# shorthands. Statement order matters: the checkpoint is loaded before its
# counters are read.
npoints = 1000  # NOTE(review): not referenced in this fragment; confirm it is used later.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-08)
# optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9)

checkpoint_io = CheckpointIO(model, optimizer, model_selection_sign, out_dir)

# Try to restore the 'model' checkpoint. CheckpointIO.load is expected to
# raise FileExistsError when none exists, in which case training starts
# from scratch.
try:
    checkpoint_io.load('model')
except FileExistsError:
    print("start from scratch")

# Counters and best metric come from the checkpoint object. When nothing was
# restored these are presumably its initial values; confirm in CheckpointIO.
# NOTE(review): these may be TF variables rather than Python numbers.
epoch_it = checkpoint_io.ckpt.epoch_it
it = checkpoint_io.ckpt.it
metric_val_best = checkpoint_io.ckpt.metric_val_best

trainer = config.get_trainer(model, optimizer, cfg)

# Hack because of a previous bug in the code: an infinite best metric of either
# sign is reset to -model_selection_sign * inf.
if metric_val_best == np.inf or metric_val_best == -np.inf:
    metric_val_best = -model_selection_sign * np.inf

print('Current best validation metric (%s): %.8f' %
      (model_selection_metric, metric_val_best))

# Shorthands for the cfg['training'] *_every settings (presumably iteration
# counts between print/checkpoint/validate/visualize steps).
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']

# log
        shuffle=True,
        collate_fn=data.collate_remove_none,
    )
    # Take one batch from val_loader (presumably for visualization, given the
    # name data_viz).
    data_viz = next(iter(val_loader))
    # Build the model; it is told the size of the training dataset.
    model = config.get_model(cfg,
                             device=device,
                             len_dataset=len(train_dataset))

    # Initialize training
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Generator built from the model and handed to the trainer below; its
    # role is not visible in this fragment.
    generator = config.get_generator(model, cfg, device=device)

    trainer = config.get_trainer(model,
                                 optimizer,
                                 cfg,
                                 device=device,
                                 generator=generator)
    # Restore 'model.pt' (presumably resolved relative to out_dir) onto
    # `device`. CheckpointIO.load is expected to raise FileExistsError when
    # the file is missing; an empty dict means "start fresh".
    checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
    try:
        load_dict = checkpoint_io.load('model.pt', device=device)
    except FileExistsError:
        load_dict = dict()

    # Without a checkpoint the counters default to -1 and the best metric to
    # -model_selection_sign * inf.
    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    metric_val_best = load_dict.get('loss_val_best',
                                    -model_selection_sign * np.inf)

    # Workaround for old checkpoints: reset any infinite best metric to the
    # no-checkpoint default above.
    if metric_val_best == np.inf or metric_val_best == -np.inf:
        metric_val_best = -model_selection_sign * np.inf
# Example #4
# Loader
dataloader = dataset.loader()

model = config.get_model(cfg, dataset=dataset)
# CheckpointIO takes an optimizer, but this optimizer is only passed to it;
# the trainer below receives None, so a throwaway Adam instance suffices.
dummy_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-08)

checkpoint_io = CheckpointIO(model, dummy_optimizer, checkpoint_dir=out_dir)

try:
    checkpoint_io.load(cfg['test']['model_file'])
except FileExistsError:
    # Evaluating without trained weights is pointless, so stop with a
    # non-zero status so that calling scripts see the failure. The bare
    # ``exit()`` builtin would exit with status 0 and only exists when the
    # ``site`` module is loaded.
    print('Model file does not exist. Exiting.')
    raise SystemExit(1)

# Trainer (evaluation only: no optimizer)
trainer = config.get_trainer(model, None, cfg)

eval_dicts = []
print('Evaluating networks...')
for it, data in enumerate(tqdm(dataloader)):
    if data is None:
        print('Invalid data.')
        continue
    # Get index etc.
    # idx = data['idx'].item()
    idx = it

    try:
        model_dict = dataset.get_model_dict(idx)
# Output paths for the evaluation results.
out_file = os.path.join(out_dir, 'eval_full.pkl')
out_file_class = os.path.join(out_dir, 'eval.csv')

# Dataset
dataset = config.get_dataset('test', cfg, return_idx=True)
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(out_dir, model=model)
try:
    checkpoint_io.load(cfg['test']['model_file'])
except FileExistsError:
    # Evaluating without trained weights is pointless, so stop with a
    # non-zero status so that calling scripts see the failure. The bare
    # ``exit()`` builtin would exit with status 0 and only exists when the
    # ``site`` module is loaded.
    print('Model file does not exist. Exiting.')
    raise SystemExit(1)

# Trainer (evaluation only: no optimizer)
trainer = config.get_trainer(model, None, cfg, device=device)

# Print the model and its total number of parameter elements.
nparameters = sum(p.numel() for p in model.parameters())
print(model)
print('Total number of parameters: %d' % nparameters)

# Evaluate: put the model in evaluation mode before running inference.
model.eval()

eval_dicts = []
print('Evaluating networks...')
test_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=1,
                                          shuffle=False,