Code Example #1
File: train.py  Project: birds-on-mars/birdsonearth
import os
import pickle

import torch

# The project-local modules used below are imported elsewhere in the
# original train.py as d (MelSpecDataset), m (VGGish) and t (Trainer).


def start_training_with(params):
    '''
    Takes a params object and expects ready-to-use spectrograms
    in params.mel_spec_root.
    Sets up all requirements for training, runs the training, and
    returns the trained model.
    '''
    # setup
    device = torch.device(params.device)
    n_classes = len(os.listdir(params.mel_spec_root))
    params.n_classes = n_classes
    print('setting up training for {} classes'.format(n_classes))
    dataset = d.MelSpecDataset(params)
    net = m.VGGish(params)
    net.init_weights()
    net.freeze_bottom()
    new_top = torch.nn.Linear(net.out_dims * 512, net.n_classes)
    net.classifier = new_top
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    trainer = t.Trainer(net, dataset, criterion, optimizer, params, device)

    # starting training
    print('start training on {} for {} epochs'.format(device, params.n_epochs))
    trainer.run_training()

    # saving model weights and class labels
    if params.save_model:
        print('saving weights and class labels')
        net.save_weights()
        print(dataset.labels)
        with open(os.path.join(params.model_zoo, params.name + '.pkl'),
                  'wb') as f:
            pickle.dump(dataset.labels, f)

    return net, dataset.labels
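
For reference, a minimal usage sketch. The params container here is hypothetical (a plain namespace); the field names mirror the attributes start_training_with reads directly, and the downstream classes (MelSpecDataset, VGGish, Trainer) may read additional fields:

from types import SimpleNamespace

# Hypothetical params object; field names are taken from the
# attributes accessed in start_training_with above.
params = SimpleNamespace(
    device='cpu',                    # or 'cuda:0'
    mel_spec_root='data/mel_specs',  # one subfolder per class
    model_zoo='models',              # where weights and labels are saved
    name='birds_vggish',
    n_epochs=20,
    save_model=True,
)

net, labels = start_training_with(params)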
Code Example #2
File: train.py  Project: tangh/simple-pytorch-fcn
# The excerpt begins mid-conditional: the `if` header this `else` pairs with
# (and the scheduler construction) are not shown; `use_scheduler` is a
# hypothetical stand-in for the elided condition. Names such as args,
# checkpoint, logger, model and optimizer are defined earlier in the script.
if use_scheduler:
    if args.resume and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        logger.info("Resume scheduler state dict from checkpoint.")
else:
    scheduler = None

loss_reduction = "mean" if args.normalize_loss else "sum"
criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction=loss_reduction)


# --------------------------------------------------------------------------- #
# start training
# --------------------------------------------------------------------------- #
fcn_trainer = trainer.Trainer(
    device=device,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    save_dir=args.save_dir,
    max_iter=args.max_iter,
    validate_interval=4000,
)
fcn_trainer.epoch = start_epoch
fcn_trainer.iteration = start_iteration
fcn_trainer.best_mean_iou = best_mean_iou

fcn_trainer.train()
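
As a complement, a sketch of the checkpoint writer that the resume logic above expects. Only the "scheduler_state_dict" key is confirmed by the excerpt; the other key names and the file name are assumptions:

# Hypothetical save side of the resume logic above; key names other
# than "scheduler_state_dict" are assumed for illustration.
torch.save(
    {
        "epoch": fcn_trainer.epoch,
        "iteration": fcn_trainer.iteration,
        "best_mean_iou": fcn_trainer.best_mean_iou,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": (
            scheduler.state_dict() if scheduler is not None else None
        ),
    },
    os.path.join(args.save_dir, "checkpoint.pth"),
)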
Code Example #3
def main():
    parser = ArgumentParser()

    parser.add_argument('--train_path',
                        required=True,
                        help='The path to training images')

    parser.add_argument('--mask_path',
                        required=True,
                        help='The path to mask images')

    parser.add_argument(
        '-warm_up_generator',
        action='store_true',
        help='Training generator model only with reconstruction loss')

    parser.add_argument(
        '-from_weights',
        action='store_true',
        help='Use this command to continue training from weights')

    parser.add_argument('--gpu',
                        default='0',
                        help='index of GPU to be used (default: %(default)s)')

    args = parser.parse_args()

    training_utils.set_visible_gpu(args.gpu)
    if args.warm_up_generator:
        log.info(
            'Performing generator training only with the reconstruction loss.')

    config = main_config.MainConfig(MAIN_CONFIG_FILE)
    wgan_batch_size = config.training.wgan_training_ratio * config.training.batch_size

    train_path = os.path.expanduser(args.train_path)
    mask_path = os.path.expanduser(args.mask_path)

    gmcnn_gan_model = gmcnn_gan.GMCNNGan(
        batch_size=config.training.batch_size,
        img_height=config.training.img_height,
        img_width=config.training.img_width,
        num_channels=config.training.num_channels,
        warm_up_generator=args.warm_up_generator,
        config=config)

    if args.from_weights:
        log.info('Continue training from checkpoint...')
        gmcnn_gan_model.load()

    img_dataset = datasets.Dataset(train_path=train_path,
                                   test_path=train_path,
                                   batch_size=wgan_batch_size,
                                   img_height=config.training.img_height,
                                   img_width=config.training.img_width)

    if img_dataset.train_set.samples < wgan_batch_size:
        log.error(
            'Number of training images [%s] is lower than WGAN batch size [%s]',
            img_dataset.train_set.samples, wgan_batch_size)
        exit(1)  # non-zero status: abort on error

    mask_dataset = datasets.MaskDataset(train_path=mask_path,
                                        batch_size=wgan_batch_size,
                                        img_height=config.training.img_height,
                                        img_width=config.training.img_width)

    if mask_dataset.train_set.samples < wgan_batch_size:
        log.error(
            'Number of training mask images [%s] is lower than WGAN batch size [%s]',
            mask_dataset.train_set.samples, wgan_batch_size)
        exit(1)  # non-zero status: abort on error

    gmcnn_gan_trainer = trainer.Trainer(
        gan_model=gmcnn_gan_model,
        img_dataset=img_dataset,
        mask_dataset=mask_dataset,
        batch_size=config.training.batch_size,
        img_height=config.training.img_height,
        img_width=config.training.img_width,
        num_epochs=config.training.num_epochs,
        save_model_steps_period=config.training.save_model_steps_period)

    gmcnn_gan_trainer.train()
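
A possible invocation, derived from the argument definitions above (the script name and paths are placeholders). Both datasets are checked against wgan_batch_size = wgan_training_ratio * batch_size; with, say, batch_size 16 and wgan_training_ratio 5, each dataset must hold at least 80 samples:

  # Plain training run (assuming the script is saved as train.py):
  python train.py --train_path ~/data/images --mask_path ~/data/masks --gpu 0
  # Generator warm-up with reconstruction loss only:
  python train.py --train_path ~/data/images --mask_path ~/data/masks -warm_up_generator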
Code Example #4
def main():
    parser = ArgumentParser()

    parser.add_argument('--train_path',
                        required=True,
                        help='The path to training images')

    parser.add_argument('--mask_path',
                        required=True,
                        help='The path to mask images')

    parser.add_argument('--experiment_name',
                        required=True,
                        help='The name of experiment')

    parser.add_argument(
        '-warm_up_generator',
        action='store_true',
        help='Training generator model only with reconstruction loss')

    parser.add_argument(
        '-from_weights',
        action='store_true',
        help='Use this command to continue training from weights')

    parser.add_argument('--gpu',
                        default='0',
                        help='index of GPU to be used (default: %(default)s)')

    args = parser.parse_args()

    output_paths = constants.OutputPaths(experiment_name=args.experiment_name)
    training_utils.set_visible_gpu(args.gpu)
    if args.warm_up_generator:
        log.info(
            'Performing generator training only with the reconstruction loss.')

    config = main_config.MainConfig(MAIN_CONFIG_FILE)
    wgan_batch_size = config.training.wgan_training_ratio * config.training.batch_size

    train_path = os.path.expanduser(args.train_path)
    mask_path = os.path.expanduser(args.mask_path)

    gmcnn_gan_model = gmcnn_gan.GMCNNGan(
        batch_size=config.training.batch_size,
        img_height=config.training.img_height,
        img_width=config.training.img_width,
        num_channels=config.training.num_channels,
        warm_up_generator=args.warm_up_generator,
        config=config,
        output_paths=output_paths)

    #if args.from_weights:
    #  log.info('Continue training from checkpoint...')
    #  gmcnn_gan_model.load()

    # Look for the newest weights. Note: folders.sort() is lexicographic,
    # so this picks the latest run only if folder names sort
    # chronologically (e.g. zero-padded timestamps).
    weights_folder = output_paths.output_weights_path
    folders = [f.path for f in os.scandir(weights_folder) if f.is_dir()]
    folders.sort()
    last_folder = folders[-1]
    log.info('Loading weights from folder: %s', last_folder)
    gmcnn_gan_model.load(last_folder)

    img_dataset = datasets.Dataset(train_path=train_path,
                                   test_path=train_path,
                                   batch_size=wgan_batch_size,
                                   img_height=config.training.img_height,
                                   img_width=config.training.img_width)

    if img_dataset.train_set.samples < wgan_batch_size:
        log.error(
            'Number of training images [%s] is lower than WGAN batch size [%s]',
            img_dataset.train_set.samples, wgan_batch_size)
        exit(1)  # non-zero status: abort on error

    mask_dataset = datasets.MaskDataset(train_path=mask_path,
                                        batch_size=wgan_batch_size,
                                        img_height=config.training.img_height,
                                        img_width=config.training.img_width)

    if mask_dataset.train_set.samples < wgan_batch_size:
        log.error(
            'Number of training mask images [%s] is lower than WGAN batch size [%s]',
            mask_dataset.train_set.samples, wgan_batch_size)
        exit(1)  # non-zero status: abort on error

    gmcnn_gan_trainer = trainer.Trainer(
        gan_model=gmcnn_gan_model,
        img_dataset=img_dataset,
        mask_dataset=mask_dataset,
        batch_size=config.training.batch_size,
        img_height=config.training.img_height,
        img_width=config.training.img_width,
        num_epochs=config.training.num_epochs,
        save_model_steps_period=config.training.save_model_steps_period,
        output_paths=output_paths,
        callback=callback)  # callback is defined elsewhere in the script, outside this excerpt

    gmcnn_gan_trainer.train()

    gmcnn_gan_trainer.gan_model.save()
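
A possible invocation for this variant (script name, paths, and experiment name are placeholders). Unlike Example #3, this script loads the newest weights unconditionally (the -from_weights guard is commented out), so the experiment's weights folder must already contain at least one saved run:

  python train.py --train_path ~/data/images --mask_path ~/data/masks --experiment_name run_01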