save_root=save_root,
    exp_name=args.exp_name,
    example_input=example_input,
    enable_save_trace=enable_save_trace,
    schedulers={'lr': lr_sched},
    valid_metrics=valid_metrics,
    preview_batch=preview_batch,
    preview_interval=5,
    inference_kwargs=inference_kwargs,
    hparams=hparams,
    # enable_videos=True,  # Uncomment to enable videos in tensorboard
    out_channels=out_channels,
    ipython_shell=args.ipython,
    # extra_save_steps=range(0, max_steps, 10_000),
    # mixed_precision=True,  # Enable to use Apex for mixed precision training
)

if args.deterministic:
    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently disable this guard.
    if trainer.num_workers > 1:
        raise ValueError('num_workers > 1 introduces indeterministic behavior')

# Archive the training script, src folder and env info next to the results
Backup(script_path=__file__, save_path=trainer.save_path).archive_backup()

# Start training; bounded by both a step budget and a runtime budget
trainer.run(max_steps=max_steps, max_runtime=max_runtime)

# How to re-calculate mean, std and class_weights for other datasets:
#  dataset_mean = utils.calculate_means(train_dataset.inputs)
#  dataset_std = utils.calculate_stds(train_dataset.inputs)
#  class_weights = torch.tensor(utils.calculate_class_weights(train_dataset.targets))
Esempio n. 2
0
def _build_segmentation_model(acont: ArgsContainer):
    """Instantiate the network selected by the settings in ``acont``.

    Returns:
        A tuple ``(model, lcp_flag)``. ``lcp_flag`` is True only when a
        ``ConvAdaptSeg`` model was built.
    """
    if acont.architecture == 'lcp' or acont.model == 'ConvAdaptSeg':
        kwargs = {}
        if acont.model == 'ConvAdaptSeg':
            kwargs = dict(kernel_num=acont.pl, architecture=acont.architecture, activation=acont.act,
                          norm=acont.norm_type)
        conv = dict(layer=acont.conv[0], kernel_separation=acont.conv[1])
        model = ConvAdaptSeg(acont.input_channels, acont.class_num, get_conv(conv), get_search(acont.search),
                             **kwargs)
        return model, True

    if acont.use_big:
        model = SegBig(acont.input_channels, acont.class_num, trs=acont.track_running_stats, dropout=acont.dropout,
                       use_bias=acont.use_bias, norm_type=acont.norm_type, use_norm=acont.use_norm,
                       kernel_size=acont.kernel_size, neighbor_nums=acont.neighbor_nums, reductions=acont.reductions,
                       first_layer=acont.first_layer, padding=acont.padding, nn_center=acont.nn_center,
                       centroids=acont.centroids, pl=acont.pl, normalize=acont.cp_norm)
        return model, False

    model = SegAdapt(acont.input_channels, acont.class_num, architecture=acont.architecture,
                     trs=acont.track_running_stats, dropout=acont.dropout, use_bias=acont.use_bias,
                     norm_type=acont.norm_type, kernel_size=acont.kernel_size, padding=acont.padding,
                     nn_center=acont.nn_center, centroids=acont.centroids, kernel_num=acont.pl,
                     normalize=acont.cp_norm, act=acont.act)
    return model, False


def _build_optimizer(acont: ArgsContainer, model, lr):
    """Create the optimizer named by ``acont.optimizer`` ('adam' or 'sgd').

    Raises:
        ValueError: If ``acont.optimizer`` is neither 'adam' nor 'sgd'.
    """
    if acont.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr)
    if acont.optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.5e-5)
    raise ValueError('Unknown optimizer')


def _build_lr_scheduler(acont: ArgsContainer, optimizer, step_size, gamma):
    """Create the LR scheduler named by ``acont.scheduler``.

    ``step_size`` and ``gamma`` are only used by the 'steplr' variant; the
    'cosannwarm' variant uses fixed ``T_0=5000`` and ``T_mult=2``.

    Raises:
        ValueError: If ``acont.scheduler`` is neither 'steplr' nor 'cosannwarm'.
    """
    if acont.scheduler == 'steplr':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma)
    if acont.scheduler == 'cosannwarm':
        return CosineAnnealingWarmRestarts(optimizer, T_0=5000, T_mult=2)
    raise ValueError('Unknown scheduler')


def training_thread(acont: ArgsContainer):
    """Set up and run one training session described by ``acont``.

    Seeds the torch, numpy and stdlib RNGs, builds the model, training
    dataset, optimizer, LR scheduler and loss from the attributes of
    ``acont``, creates a ``Trainer3d``, archives the script and the argument
    container into the trainer's save path and finally starts training.

    Args:
        acont: Argument container holding all configuration attributes read
            below (model, dataset, optimizer, scheduler and trainer settings).

    Raises:
        ValueError: If ``acont.optimizer`` or ``acont.scheduler`` names an
            unsupported option.
    """
    torch.cuda.empty_cache()
    lr = 1e-3
    lr_stepsize = 10000
    lr_dec = 0.995
    # Step budget handed to the trainer: max_step_size divided by the batch
    # size, truncated to an integer.
    max_steps = int(acont.max_step_size / acont.batch_size)

    # Seed every RNG in use before the model is created.
    torch.manual_seed(acont.random_seed)
    np.random.seed(acont.random_seed)
    random.seed(acont.random_seed)

    device = torch.device('cuda') if acont.use_cuda else torch.device('cpu')

    # load model
    model, lcp_flag = _build_segmentation_model(acont)

    batch_size = acont.batch_size

    train_transforms = clouds.Compose(acont.train_transforms)
    train_ds = TorchHandler(data_path=acont.train_path, sample_num=acont.sample_num, nclasses=acont.class_num,
                            feat_dim=acont.input_channels, density_mode=acont.density_mode,
                            ctx_size=acont.chunk_size, bio_density=acont.bio_density,
                            tech_density=acont.tech_density, transform=train_transforms,
                            obj_feats=acont.features, label_mappings=acont.label_mappings,
                            hybrid_mode=acont.hybrid_mode, splitting_redundancy=acont.splitting_redundancy,
                            label_remove=acont.label_remove, sampling=acont.sampling, padding=acont.padding,
                            split_on_demand=acont.split_on_demand, split_jitter=acont.split_jitter,
                            epoch_size=acont.epoch_size, workers=acont.workers, voxel_sizes=acont.voxel_sizes,
                            ssd_exclude=acont.ssd_exclude, ssd_include=acont.ssd_include,
                            ssd_labels=acont.ssd_labels, exclude_borders=acont.exclude_borders,
                            rebalance=acont.rebalance, extend_no_pred=acont.extend_no_pred)

    optimizer = _build_optimizer(acont, model, lr)
    scheduler = _build_lr_scheduler(acont, optimizer, lr_stepsize, lr_dec)

    # Use the class weights from the container if any were provided
    weights = None
    if acont.class_weights is not None:
        weights = torch.from_numpy(acont.class_weights).float()

    criterion = torch.nn.CrossEntropyLoss(weight=weights)
    if acont.use_cuda:
        criterion.cuda()

    # Validation is optional; the trainer receives None when it is disabled.
    val_path = acont.val_path if acont.use_val else None

    trainer = Trainer3d(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        train_dataset=train_ds,
        v_path=val_path,
        val_freq=acont.val_freq,
        val_red=acont.val_iter,
        channel_num=acont.input_channels,
        batchsize=batch_size,
        num_workers=4,
        save_root=acont.save_root,
        exp_name=acont.name,
        num_classes=acont.class_num,
        schedulers={"lr": scheduler},
        target_names=acont.target_names,
        stop_epoch=acont.stop_epoch,
        enable_tensorboard=False,
        lcp_flag=lcp_flag,
    )
    # Archiving training script, src folder, env info
    Backup(script_path=__file__, save_path=trainer.save_path).archive_backup()
    # Store the argument container both as pickle and as readable text.
    acont.save2pkl(f'{trainer.save_path}/argscont.pkl')
    # The `with` block closes the file; no explicit close() is needed.
    with open(f'{trainer.save_path}/argscont.txt', 'w') as f:
        f.write(str(acont.attr_dict))

    trainer.run(max_steps)