Esempio n. 1
0
def main(config):
    """Build the DAVIS train/validation loaders and run the trainer's test pass.

    ``config`` must provide the ``train_data_dir`` and ``valid_data_dir``
    entries.  Both loaders use batch size 1, no shuffling, no validation
    split and a single worker; only their ``training`` flag differs.
    """

    def davis_loader(data_dir_key, is_training):
        # Dataset is constructed before its loader, matching the original order.
        return Dataloader_DAVIS(dataset=Dataset_DAVIS(config[data_dir_key]),
                                batch_size=1,
                                shuffle=False,
                                validation_split=0.0,
                                num_workers=1,
                                training=is_training)

    # Keyword arguments are evaluated left to right: train loader, then valid loader.
    trainer = Trainer(train_loader=davis_loader('train_data_dir', True),
                      config=config,
                      valid_loader=davis_loader('valid_data_dir', False))
    # NOTE(review): the path is forwarded to Trainer.test; presumably where the
    # test visualisation is written — confirm against Trainer.
    trainer.test("results/test_results.jpg")
Esempio n. 2
0
# -*- coding: utf-8 -*-

import os
from trainer.trainer import Trainer


# Script entry: build a Trainer with its default constructor arguments,
# train it, then run its test pass.  This executes at import time.
# NOTE(review): the positional arguments to train() are project-specific
# (presumably epochs / batch size / a flag) — confirm against Trainer.train.
trainer = Trainer()
trainer.train(10, 1, False)
trainer.test()
Esempio n. 3
0
                           question_path=config.question_path,
                           batch_size=config.batch_size,
                           window_size=saved_config.window_size,
                           device=config.device)

    # Fetch the test data from the loader configured above.
    test = loader.load_test_data()
    # Select the model/trainer pair by name: "SeqModel3" uses the sequence
    # trainer; any other value falls back to SeqModel with the default Trainer.
    if config.model == "SeqModel3":
        # Imported inside the branch so only the selected modules are loaded.
        from model.model import SeqModel3
        from trainer.sequence_trainer import seq_Trainer
        # Architecture sizes come from saved_config (presumably the config the
        # model was trained with — TODO confirm); the device comes from config.
        model = SeqModel3(input_size=options.get_movie_size(),
                          word_vec_dim=saved_config.word_vec_dim,
                          hidden_size=saved_config.hidden_size,
                          output_size=options.get_movie_size(),
                          device=config.device).to(config.device)

        # NLLLoss expects log-probabilities; presumably the model ends in
        # log_softmax — confirm.
        crits = nn.NLLLoss()
        trainer = seq_Trainer(model, crits, config, options)
    else:
        from model.model import SeqModel
        from trainer.trainer import Trainer

        # NOTE(review): n_layers is hard-coded to 4 instead of being read from
        # saved_config; confirm it matches the trained checkpoint.
        model = SeqModel(input_size=options.get_input_size(),
                         hidden_size=saved_config.hidden_size,
                         output_size=options.get_movie_size(),
                         n_layers=4,
                         device=config.device).to(config.device)

        crits = nn.NLLLoss()
        trainer = Trainer(model, crits, config, options)
    # Evaluate the selected trainer on the test data.
    trainer.test(test)
Esempio n. 4
0
def main(args, config):
    """Train the audio-only model, then evaluate its best checkpoint on the test split.

    Args:
        args: parsed command-line options; reads ``arch``, ``masking_time``,
            ``masking_freq``, ``batch_size``, ``workers``, ``lr``,
            ``momentum`` and ``weight_decay``.
        config: project configuration object; provides ``get_logger``,
            ``init_obj`` and the ``loss``, ``loss_continuous``, ``metrics``
            and ``metrics_continuous`` entries.
    """
    # NOTE(review): 8 is presumably the number of output classes — confirm.
    network = AudioOnly(8, base_model=args.arch)

    import torchaudio.transforms as at

    # Optional training-time masking; a size of 0 disables that mask.
    masks = []
    for size, mask_cls in ((args.masking_time, at.TimeMasking),
                           (args.masking_freq, at.FrequencyMasking)):
        if size != 0:
            masks.append(mask_cls(size))
    train_set = AudioDataSet("train", transform=transforms.Compose(masks))

    # Validation and test data go through an empty transform pipeline.
    eval_transform = transforms.Compose([])

    def make_loader(source, shuffle, **extra):
        """Wrap ``source`` in a DataLoader with the shared batch/worker settings."""
        return torch.utils.data.DataLoader(
            source,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.workers,
            pin_memory=True,
            collate_fn=None,
            **extra)

    train_loader = make_loader(train_set, True, sampler=None, drop_last=False)
    val_loader = make_loader(AudioDataSet("val", transform=eval_transform), False)

    logger = config.get_logger('train')
    logger.info(network)

    loss_categorical = getattr(module_loss, config['loss'])
    loss_continuous = getattr(module_loss, config['loss_continuous'])
    metric_fns = [getattr(module_metric, name) for name in config['metrics']]
    metric_fns_continuous = [getattr(module_metric, name)
                             for name in config['metrics_continuous']]

    optimizer = torch.optim.SGD(network.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    # Echo the starting learning rate of every parameter group.
    for lr in [group['lr'] for group in optimizer.param_groups]:
        print(lr)

    def make_trainer(eval_loader):
        """Build a categorical-only Trainer whose validation loader is ``eval_loader``."""
        return Trainer(network, loss_categorical, loss_continuous, metric_fns,
                       metric_fns_continuous, optimizer,
                       categorical=True,
                       continuous=False,
                       config=config,
                       data_loader=train_loader,
                       valid_data_loader=eval_loader,
                       lr_scheduler=lr_scheduler)

    trainer = make_trainer(val_loader)
    trainer.train()

    test_loader = make_loader(AudioDataSet("test", transform=eval_transform), False)

    # Restore the weights stored as model_best.pth in the trainer's checkpoint dir.
    best_path = str(trainer.checkpoint_dir / 'model_best.pth')
    checkpoint = torch.load(best_path)
    network.load_state_dict(checkpoint['state_dict'], strict=True)
    print('loaded', best_path, 'best_epoch', checkpoint['epoch'])

    # A fresh Trainer with the test loader in the validation slot runs the test pass.
    make_trainer(test_loader).test()
Esempio n. 5
0
def main(config):
    """Train a model described by ``config``, then evaluate its best checkpoint.

    Training uses the ``data_loader``, ``arch``, ``loss``, ``metrics``,
    ``optimizer`` and ``lr_scheduler`` sections of ``config``.  Afterwards
    ``model_best.pth`` is loaded from ``config.save_dir`` into the model, and
    the per-sample test loss and metrics reported by ``trainer.test()`` are
    logged.
    """
    # ---- Training -------------------------------------------------------
    train_logger = config.get_logger('train')

    # Data loader plus the validation/test splits and graph-side inputs it exposes.
    loader = config.init_obj('data_loader', module_data)
    valid_loader = loader.split_dataset(valid=True)
    test_loader = loader.split_dataset(test=True)
    feature_index = loader.get_feature_index()
    cell_neighbors = loader.get_cell_neighbor_set()
    drug_neighbors = loader.get_drug_neighbor_set()
    node_counts = loader.get_node_num_dict()

    arch_args = config['arch']['args']
    model = module_arch(protein_num=node_counts['protein'],
                        cell_num=node_counts['cell'],
                        drug_num=node_counts['drug'],
                        emb_dim=arch_args['emb_dim'],
                        n_hop=arch_args['n_hop'],
                        l1_decay=arch_args['l1_decay'],
                        therapy_method=arch_args['therapy_method'])
    train_logger.info(model)

    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, name) for name in config['metrics']]

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = config.init_obj('optimizer', torch.optim,
                                (p for p in model.parameters() if p.requires_grad))
    # Remove the lr_scheduler lines to train without a scheduler.
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model,
                      criterion,
                      metrics,
                      optimizer,
                      config=config,
                      data_loader=loader,
                      feature_index=feature_index,
                      cell_neighbor_set=cell_neighbors,
                      drug_neighbor_set=drug_neighbors,
                      valid_data_loader=valid_loader,
                      test_data_loader=test_loader,
                      lr_scheduler=lr_scheduler)
    trainer.train()

    # ---- Testing --------------------------------------------------------
    test_logger = config.get_logger('test')
    test_logger.info(model)
    test_metrics = [getattr(module_metric, name) for name in config['metrics']]

    # Reload the best checkpoint into the same model object the trainer holds.
    best_path = str(config.save_dir / 'model_best.pth')
    test_logger.info('Loading checkpoint: {} ...'.format(best_path))
    model.load_state_dict(torch.load(best_path)['state_dict'])

    outcome = trainer.test()
    n_samples = outcome['n_samples']
    # Totals are divided by the sample count; 'loss' is inserted first, then metrics.
    summary = {'loss': outcome['total_loss'] / n_samples}
    for idx, metric in enumerate(test_metrics):
        summary[metric.__name__] = outcome['total_metrics'][idx].item() / n_samples
    test_logger.info(summary)
Esempio n. 6
0
def main(args, config):
    """Train a TSN model on the chosen modality, then evaluate its best checkpoint.

    Args:
        args: parsed command-line options (modality, architecture, data,
            optimisation and loader settings are read from it).
        config: project configuration object providing ``get_logger``,
            ``init_obj`` and the ``loss``, ``loss_continuous``, ``metrics``
            and ``metrics_continuous`` entries.

    Raises:
        ValueError: if ``args.modality`` is not one of ``'RGB'``,
            ``'depth'``, ``'Flow'`` or ``'RGBDiff'``.
    """
    # Frames per segment: one for RGB, otherwise the configured length.
    # Unknown modalities are rejected here; otherwise data_length would be
    # unbound and the failure would surface later as an UnboundLocalError.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ('depth', 'Flow', 'RGBDiff'):
        data_length = args.data_length
    else:
        raise ValueError(
            "Unsupported modality {!r}; expected one of "
            "'RGB', 'depth', 'Flow', 'RGBDiff'".format(args.modality))

    # NOTE(review): 8 is presumably the number of output classes — confirm.
    model = TSN(8,
                args.num_segments,
                args.modality,
                modalities_fusion=args.modalities_fusion,
                num_feats=args.num_feats,
                base_model=args.arch,
                new_length=data_length,
                embed=args.embed,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                categorical=args.categorical,
                continuous=args.continuous,
                audio=args.audio)

    # RGBDiff inputs skip mean/std normalisation.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(model.input_mean, model.input_std)
    else:
        normalize = IdentityTransform()

    # Frame file-name template shared by every split; args.flow_prefix is
    # only read for modalities that need it (Flow).
    if args.modality in ["RGB", "RGBDiff", "depth"]:
        image_tmpl = "img_{:05d}.jpg"
    else:
        image_tmpl = args.flow_prefix + "{}_{:05d}.jpg"

    # BNInception inputs are stacked with roll=True and not divided (div=False).
    is_bninception = args.arch == 'BNInception'

    # Training split: resize to 256x256, random flip, random 224 crop.
    train_dataset = TSNDataSet(
        "train",
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=image_tmpl,
        transform=torchvision.transforms.Compose([
            GroupScale((256, 256)),
            GroupRandomHorizontalFlip(),
            GroupRandomCrop(224),
            Stack(roll=is_bninception),
            ToTorchFormatTensor(div=not is_bninception),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=None,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               collate_fn=None,
                                               drop_last=False)

    def make_eval_loader(split):
        """Build a non-shuffled loader for ``split`` resized to 224x224 without random flips/crops."""
        eval_dataset = TSNDataSet(
            split,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=image_tmpl,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale((224, 224)),
                Stack(roll=is_bninception),
                ToTorchFormatTensor(div=not is_bninception),
                normalize,
            ]))
        return torch.utils.data.DataLoader(eval_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           collate_fn=None)

    val_loader = make_eval_loader("val")

    logger = config.get_logger('train')
    logger.info(model)

    criterion_categorical = getattr(module_loss, config['loss'])
    criterion_continuous = getattr(module_loss, config['loss_continuous'])

    metrics = [getattr(module_metric, met) for met in config['metrics']]
    metrics_continuous = [
        getattr(module_metric, met) for met in config['metrics_continuous']
    ]

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler,
                                   optimizer)

    # Echo the starting learning rate of every parameter group.
    for param_group in optimizer.param_groups:
        print(param_group['lr'])

    def build_trainer(eval_loader):
        """Build a Trainer that uses ``eval_loader`` as its validation loader."""
        return Trainer(model,
                       criterion_categorical,
                       criterion_continuous,
                       metrics,
                       metrics_continuous,
                       optimizer,
                       categorical=args.categorical,
                       continuous=args.continuous,
                       config=config,
                       data_loader=train_loader,
                       valid_data_loader=eval_loader,
                       lr_scheduler=lr_scheduler)

    trainer = build_trainer(val_loader)
    trainer.train()

    test_loader = make_eval_loader("test")

    # Load the best checkpoint from the trainer's checkpoint dir and evaluate on test.
    best_path = str(trainer.checkpoint_dir / 'model_best.pth')
    cp = torch.load(best_path)
    model.load_state_dict(cp['state_dict'], strict=True)
    print('loaded', best_path, 'best_epoch', cp['epoch'])

    # A fresh Trainer with the test loader in the validation slot runs the test pass.
    build_trainer(test_loader).test()