Example No. 1
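Builds a video training pipeline from command-line arguments: it constructs the network and criterion, loads spatial/temporal transforms and dataset settings from JSON files, wraps the dataset in a DataLoader with a custom collate function, and hands everything to a Trainer.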
def main(args: argparse.Namespace):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = model.utils.get_model(args).to(device)
    criterion = model.utils.get_criterion(args)
    settings = TrainSettingsDecoder().decode(args, network)
    mean, std = dataset.get_stats()
    spatial_transforms = dataset.SpatialTransformRepository(
        mean, std).get_transform_obj(args.transforms_json)
    temporal_transforms = dataset.TemporalTransformRepository(
    ).get_transform_obj(args.transforms_json)
    with open(args.config_json) as f:
        config: Mapping[str, Any] = json.load(f)
    videodata_repository = dataset.get_dataset(
        args,
        spatial_transforms,
        temporal_transforms,
        **config,
    )
    train_loader = DataLoader(
        videodata_repository,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=dataset.collate_data,
        drop_last=False,
    )
    trainer = train.Trainer(settings, network, criterion, device)
    trainer.train(train_loader)
Example No. 2
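A validation-only entry point for a CRNN recognizer: it builds the model and CTC loss, moves the model to the configured GPU (or CPU), creates a validation DataLoader, and runs validate with a strLabelConverter for decoding labels.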
def main():
    config, args = parse_arg()
    model = crnn.get_crnn(config)
    criterion = torch.nn.CTCLoss()
    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")
    model = model.to(device)
    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    acc = validate(config, val_loader, val_dataset, converter, model,
                   criterion, device)
Example No. 3
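A Chainer training script for an ELBO-based model whose data and batch size come from MATLAB .mat files. It sets up an Adam optimizer, a StandardUpdater, and a Trainer with optional early stopping, a linear warm-up of the loss weight beta, periodic snapshots, logging and reporting extensions, and a final evaluation and visualization step.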
def main(hpt):
    logger.info('load New Data From Matlab')
    if hpt.dataset.type == 'synthetic':
        treeFile = loadmat(hpt.dataset['matrixFile_struct'])
        matrixForData = treeFile.get('matrixForHS')
        hpt.dataset['depth'] = treeFile.get('depthToSave')[0, 0]
        hpt.dataset['depthReal'] = treeFile.get('depthToSaveReal')[0, 0]
        hpt.training['batch_size'] = treeFile.get('baches')[0, 0]
        print(treeFile.get('matrixForHS'))
        print(hpt.dataset.depth)
        print(hpt.training.batch_size)
    elif hpt.dataset.type == 'mnist':
        activityFile = loadmat(hpt.dataset['matrixFile_activity'])
        matrixForData = np.transpose(activityFile.get('roiActivity'))
        hpt.dataset['mnistShape'] = len(matrixForData[1, :])
        hpt.training['batch_size'] = len(matrixForData[:, 1])

    logger.info('build model')
    avg_elbo_loss = get_model(hpt)
    if hpt.general.gpu >= 0:
        avg_elbo_loss.to_gpu(hpt.general.gpu)

    logger.info('setup optimizer')
    if hpt.optimizer.type == 'adam':
        optimizer = chainer.optimizers.Adam(alpha=hpt.optimizer.lr)
    else:
        raise AttributeError
    optimizer.setup(avg_elbo_loss)

    logger.info('load dataset')
    train, valid, test = dataset.get_dataset(hpt.dataset.type, matrixForData,
                                             **hpt.dataset)

    if hpt.general.test:
        train, _ = chainer.datasets.split_dataset(train, 100)
        valid, _ = chainer.datasets.split_dataset(valid, 100)
        test, _ = chainer.datasets.split_dataset(test, 100)

    train_iter = chainer.iterators.SerialIterator(train,
                                                  hpt.training.batch_size)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  hpt.training.batch_size,
                                                  repeat=False,
                                                  shuffle=False)

    logger.info('setup updater/trainer')
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=hpt.general.gpu,
                                                loss_func=avg_elbo_loss)

    if not hpt.training.early_stopping:
        trainer = training.Trainer(updater,
                                   (hpt.training.iteration, 'iteration'),
                                   out=po.namedir(output='str'))
    else:
        trainer = training.Trainer(updater,
                                   triggers.EarlyStoppingTrigger(
                                       monitor='validation/main/loss',
                                       patients=5,
                                       max_trigger=(hpt.training.iteration,
                                                    'iteration')),
                                   out=po.namedir(output='str'))

    if hpt.training.warm_up != -1:
        time_range = (0, hpt.training.warm_up)
        trainer.extend(
            extensions.LinearShift('beta',
                                   value_range=(0.1, hpt.loss.beta),
                                   time_range=time_range,
                                   optimizer=avg_elbo_loss))

    trainer.extend(
        extensions.Evaluator(valid_iter, avg_elbo_loss,
                             device=hpt.general.gpu))
    # trainer.extend(extensions.DumpGraph('main/loss'))
    trainer.extend(extensions.snapshot_object(
        avg_elbo_loss, 'avg_elbo_loss_snapshot_iter_{.updater.iteration}'),
                   trigger=(int(hpt.training.iteration / 5), 'iteration'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/reconstr', 'main/kl_penalty', 'main/beta', 'lr',
            'elapsed_time'
        ]))
    trainer.extend(extensions.ProgressBar())

    logger.info('run training')
    trainer.run()

    logger.info('save last model')
    extensions.snapshot_object(
        avg_elbo_loss,
        'avg_elbo_loss_snapshot_iter_{.updater.iteration}')(trainer)

    logger.info('evaluate')
    metrics = evaluate(hpt, train, test, avg_elbo_loss)
    for metric_name, metric in metrics.items():
        logger.info('{}: {:.4f}'.format(metric_name, metric))

    if hpt.general.noplot:
        return metrics

    logger.info('visualize images')
    visualize(hpt, train, test, avg_elbo_loss, treeFile)

    return metrics
Example No. 4
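A complete CRNN training loop in PyTorch: it creates the output and TensorBoard folders, optionally resumes from a checkpoint, builds a MultiStepLR or StepLR scheduler, iterates over train and validation DataLoaders, and saves a checkpoint with the validation accuracy in its filename after every epoch.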
def main():

    # load config
    config = parse_arg()

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the CRNN recognition network
    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")

    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()

    optimizer = utils.get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict,
                       output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict,
                                output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'],
                         "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

    writer_dict['writer'].close()
Example No. 5
def main():
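    # A variant of the previous training entry point that additionally supports
    # fine-tuning: it can load only the CNN weights from a checkpoint and
    # optionally freeze them before training.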

    # load config
    config = parse_arg()

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the CRNN recognition network
    model = crnn.get_crnn(config)
    #
    # checkpoint = torch.load('/data/yolov5/CRNN_Chinese_Characters_Rec/output/OWN/crnn/2020-09-15-22-13/checkpoints/checkpoint_98_acc_1.0983.pth')
    # if 'state_dict' in checkpoint.keys():
    #     model.load_state_dict(checkpoint['state_dict'])
    # else:
    #     model.load_state_dict(checkpoint)
    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")

    model = model.to(device)

    # define loss function
    # criterion = torch.nn.CTCLoss()
    criterion = CTCLoss()

    last_epoch = config.TRAIN.BEGIN_EPOCH
    optimizer = utils.get_optimizer(config, model)
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']

            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                # keep CNN weights only, stripping the leading 'cnn.' prefix
                # so the keys match model.cnn
                if 'cnn' in k:
                    model_dict[k[4:]] = v
            model.cnn.load_state_dict(model_dict)
        if config.TRAIN.FINETUNE.FREEZE:
            for p in model.cnn.parameters():
                p.requires_grad = False

    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                model.load_state_dict(checkpoint['state_dict'])
                last_epoch = checkpoint['epoch']
                # optimizer.load_state_dict(checkpoint['optimizer'])
                # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            else:
                model.load_state_dict(checkpoint)

    model_info(model)
    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict,
                       output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict,
                                output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                # "optimizer": optimizer.state_dict(),
                # "lr_scheduler": lr_scheduler.state_dict(),
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'],
                         "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc)))

    writer_dict['writer'].close()
Example No. 6
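A Chainer training script that trains on a single split (valid and test are expected to be None), supports Adam or AdaGrad, applies a learning-rate burn-in extension, aborts on non-finite values, and optionally plots loss and KL curves before the final evaluation.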
def main(hpt):

    logger.info('load dataset')
    train, valid, test = dataset.get_dataset(hpt.dataset.type, **hpt.dataset)
    assert valid is None
    assert test is None

    if hpt.general.test:
        train, _ = chainer.datasets.split_dataset(train, 100)
        chainer.set_debug(True)

    train_iter = chainer.iterators.SerialIterator(train,
                                                  hpt.training.batch_size)

    logger.info('build model')
    loss = get_model(hpt)
    if hpt.general.gpu >= 0:
        loss.to_gpu(hpt.general.gpu)

    logger.info('setup optimizer')
    if hpt.optimizer.type == 'adam':
        optimizer = chainer.optimizers.Adam(alpha=hpt.optimizer.lr)
    elif hpt.optimizer.type == 'adagrad':
        optimizer = chainer.optimizers.AdaGrad(lr=hpt.optimizer.lr)
    else:
        raise AttributeError
    optimizer.setup(loss)

    logger.info('setup updater/trainer')
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=hpt.general.gpu,
                                                loss_func=loss)

    trainer = training.Trainer(updater, (hpt.training.iteration, 'iteration'),
                               out=po.namedir(output='str'))

    lr_name = 'alpha' if hpt.optimizer.type == 'adam' else 'lr'
    trainer.extend(
        Burnin(lr_name, burnin_step=hpt.training.burnin_step,
               c=hpt.training.c))

    trainer.extend(extensions.FailOnNonNumber())

    trainer.extend(extensions.snapshot_object(
        loss, 'loss_snapshot_iter_{.updater.iteration}'),
                   trigger=(int(hpt.training.iteration / 5), 'iteration'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'main/kl_target',
            'main/kl_negative', 'lr', 'main/bound', 'elapsed_time'
        ]))
    trainer.extend(extensions.ProgressBar())

    # Save plot images to the result dir
    if (not hpt.general.noplot) and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss'],
                                  'epoch',
                                  file_name=(po.imagesdir() /
                                             'loss.png').as_posix()))
        trainer.extend(
            extensions.PlotReport(['main/kl_target', 'main/kl_negative'],
                                  'epoch',
                                  file_name=(po.imagesdir() /
                                             'kldiv.png').as_posix()))

    # Run the training
    logger.info('run training')
    trainer.run()

    logger.info('evaluate')
    metrics = evaluate(hpt, train, test, loss)
    for metric_name, metric in metrics.items():
        logger.info('{}: {:.4f}'.format(metric_name, metric))

    return metrics
Example No. 7
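An excerpt (the enclosing function definition is not shown) that restores model weights from a checkpoint, builds the strLabelConverter and CTC loss, and constructs either a training or a test DataLoader depending on args.mode.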
    # model.load_state_dict(torch.load(args.checkpoint))

    model_state_file = args.checkpoint
    if model_state_file == '':
        print(" => no checkpoint found")
    else:
        checkpoint = torch.load(model_state_file, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    # converter
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)  # get corpus

    # define loss function
    criterion = torch.nn.CTCLoss()

    if args.mode == "train":
        data_set = get_dataset(config)(config, is_train=True)
        data_loader = DataLoader(
            dataset=data_set,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=config.TRAIN.SHUFFLE,
            num_workers=config.WORKERS,
            pin_memory=config.PIN_MEMORY,
        )
    elif args.mode == "test":
        data_set = get_dataset(config)(config, is_train=False)
        data_loader = DataLoader(
            dataset=data_set,
            batch_size=config.TEST.BATCH_SIZE_PER_GPU,
            shuffle=config.TEST.SHUFFLE,
            num_workers=config.WORKERS,
            pin_memory=config.PIN_MEMORY,