Example #1
def train(config):
    assert config is not None, "No config file provided!"

    device = torch.device('cuda:{}'.format(config['gpus']) if config.get(
        'gpus', None) is not None and torch.cuda.is_available() else 'cpu')

    # 1: Load datasets
    set_seed()
    '''CIFAR10'''
    dataset = CIFAR10Dataset('data/CIFAR10/train',
                             'data/CIFAR10/trainLabels.csv')
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - len(dataset) // 5,
                  len(dataset) // 5])
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=32,
                                                   shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32)
    '''MNIST'''
    # train_dataset = MNISTDataset('data/MNIST/mnist_train.csv')
    # train_dataloader = DataLoader(train_dataset, shuffle=True,
    #                               num_workers=4, batch_size=32)

    # val_dataset = MNISTDataset('data/MNIST/mnist_test.csv')
    # val_dataloader = DataLoader(val_dataset, batch_size=1,
    #                             num_workers=4)

    # 2: Define network
    set_seed()
    net = ClAtNet(in_channels=3,
                  size=(32, 32),
                  nfeatures=64,
                  nclasses=10,
                  nheads=2,
                  dropout=0.0).to(device)
    print(net)
    # 3: Define loss
    criterion = nn.CrossEntropyLoss()
    # 4: Define Optimizer
    optimizer = torch.optim.Adam(net.parameters())
    # 5: Define metrics
    metric = Accuracy()

    # 6: Create trainer
    trainer = Trainer(device=device,
                      config=config,
                      net=net,
                      criterion=criterion,
                      optimizer=optimizer,
                      metric=metric)
    # 7: Start to train
    trainer.train(train_dataloader=train_dataloader,
                  val_dataloader=val_dataloader)
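
The set_seed helper called throughout these examples is not defined on this page. A minimal sketch, assuming it seeds every relevant RNG (the signature and default value are illustrative):

import random
import numpy as np
import torch

def set_seed(seed=0):
    # Seed Python, NumPy, and PyTorch (CPU and all GPUs) so that
    # dataset splits and weight initialization are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)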

Example #2
def train(config):
    assert config is not None, "No config file provided!"

    pprint.PrettyPrinter(indent=2).pprint(config)

    dev_id = 'cuda:{}'.format(config['gpus']) \
        if torch.cuda.is_available() and config.get('gpus', None) is not None \
        else 'cpu'
    device = torch.device(dev_id)

    # Get pretrained model
    pretrained_path = config["pretrained"]

    pretrained = None
    if pretrained_path is not None:
        pretrained = torch.load(pretrained_path, map_location=dev_id)
        for item in ["model"]:
            config[item] = pretrained["config"][item]

    # 1: Load datasets
    set_seed()
    train_dataset = get_instance(config['dataset']['train'])
    train_dataloader = get_instance(config['dataset']['train']['loader'],
                                    dataset=train_dataset)

    val_dataset = get_instance(config['dataset']['val'])
    val_dataloader = get_instance(config['dataset']['val']['loader'],
                                  dataset=val_dataset)

    # 2: Define network
    set_seed()
    model = get_instance(config['model']).to(device)

    # Train from pretrained if it is not None
    if pretrained is not None:
        model.load_state_dict(pretrained['model_state_dict'])

    # 3: Define loss
    criterion = get_instance(config['loss']).to(device)

    # 4: Define Optimizer
    optimizer = get_instance(config['optimizer'],
                             params=model.parameters())
    if pretrained is not None:
        optimizer.load_state_dict(pretrained['optimizer_state_dict'])

    # 5: Define Scheduler
    set_seed()
    scheduler = get_instance(config['scheduler'],
                             optimizer=optimizer)

    # 6: Define metrics
    set_seed()
    metric = {mcfg['name']: get_instance(mcfg,
                                         net=model, device=device)
              for mcfg in config['metric']}

    # 7: Create trainer
    trainer = Trainer(device=device,
                      config=config,
                      model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      metric=metric)

    # 8: Start to train
    set_seed()
    trainer.train(train_dataloader=train_dataloader,
                  val_dataloader=val_dataloader)
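
The get_instance helper above is also not shown here. A plausible minimal version, assuming each config entry looks like {'name': ..., 'args': {...}} and that names are resolved against a registry (REGISTRY and its contents are hypothetical):

import torch

# Hypothetical mapping from config names to classes; the real project
# presumably resolves names against its own modules instead.
REGISTRY = {
    'Adam': torch.optim.Adam,
    'CrossEntropyLoss': torch.nn.CrossEntropyLoss,
}

def get_instance(config, **kwargs):
    # Extra keyword arguments (e.g. params=..., optimizer=...) are
    # forwarded to the constructor alongside the config's own 'args'.
    cls = REGISTRY[config['name']]
    return cls(**config.get('args', {}), **kwargs)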
Example #3
def train(config):
    '''
        The training workflow consists of:
            1. Specify the device to train on (no parallelism yet);
            2. Specify train id, which is the id in config + timestamp;
            3. Load configuration from checkpoint (if specified);
            4. Get network, criterion, optimizer, and callbacks (learning rate scheduler).
               Load pretrained weights if necessary;
            5. Create trainer using all the above;
            6. Get train/val datasets;
            7. Perform training on train/val datasets.
    '''

    # TODO: parallelize training

    # Specify device
    device, dev_id = get_device(config)

    # -----------------------------------------------------------------

    # Training start time
    current_time = time.strftime('%b%d_%H-%M-%S', time.gmtime())
    print('Training starts at', current_time)

    # Get train id
    train_id = get_train_id(config['id'], current_time)

    # -----------------------------------------------------------------

    # TODO: think about continuing training on a different dataset

    # Load checkpoint configuration (if specified)
    checkpoint = None
    if config.get('checkpoint', None) is not None:
        print('Continuing from checkpoint at %s' % config['checkpoint'])
        checkpoint = torch.load(config['checkpoint'], map_location=dev_id)
        # Override config
        # TODO: what to load (arch is a must, what about the rest)?
        for cfg_item in ['arch', 'loss', 'optimizer', 'scheduler']:
            config[cfg_item] = checkpoint['config'][cfg_item]

    # -----------------------------------------------------------------

    set_seed(manualSeed)

    # Define network
    net = NetworkGetter().get(config=config['arch']).to(device)

    # Define loss
    criterion = LossGetter().get(config=config['loss']).to(device)

    # Define optim
    optimizer = OptimizerGetter().get(params=net.parameters(),
                                      config=config['optimizer'])

    # Define learning rate scheduler
    scheduler = SchedulerGetter().get(optimizer=optimizer,
                                      config=config['scheduler'])

    metrics = {
        metric: MetricGetter().get(config=cfg)
        for metric, cfg in config['metrics'].items()
    }

    # -----------------------------------------------------------------

    # TODO: Summarizer network

    # Print network
    print(net)
    print('=' * 30)

    # -----------------------------------------------------------------

    # Load pretrained weights (if specified)
    if checkpoint is not None:
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # ------------------------------------------------------------------

    # Create trainer
    trainer = Trainer(train_id=train_id,
                      config=config,
                      net=net,
                      optimizer=optimizer,
                      criterion=criterion,
                      scheduler=scheduler,
                      metrics=metrics,
                      device=device)

    # -----------------------------------------------------------------

    set_seed(manualSeed)

    # Load datasets
    train_dataset = CISPDTrain(
        data_path='data/cis-pd/training_data',
        label_path='data/cis-pd/data_labels/CIS-PD_Training_Data_IDs_Labels.csv'
    )
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, [
            len(train_dataset) - len(train_dataset) // 5,
            len(train_dataset) // 5
        ])
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   num_workers=6,
                                                   batch_size=1,
                                                   shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=6,
                                                 batch_size=1)

    # -----------------------------------------------------------------

    # Training
    trainer.train(train_dataloader=train_dataloader,
                  val_dataloader=val_dataloader)

    # -----------------------------------------------------------------

    current_time = time.strftime('%b%d_%H-%M-%S', time.gmtime())
    print('Training finishes at', current_time)
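
The checkpoint loaded in this example is expected to be a dict carrying at least 'config', 'model_state_dict', and 'optimizer_state_dict'. A sketch of the matching save side (the path is illustrative; the actual saving code is not shown on this page):

torch.save({
    'config': config,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoints/%s.pth' % train_id)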
Example #4
def train(config):
    assert config is not None, "No config file provided!"

    pprint.PrettyPrinter(indent=2).pprint(config)

    # Get device
    dev_id = 'cuda:{}'.format(config['gpus']) \
        if torch.cuda.is_available() and config.get('gpus', None) is not None \
        else 'cpu'
    device = torch.device(dev_id)

    # Get pretrained model
    pretrained_path = config["pretrained"]

    pretrained = None
    # Accept both a missing path (None) and the literal string 'None'
    # that a config file may contain.
    if pretrained_path is not None and str(pretrained_path) != 'None':
        pretrained = torch.load(pretrained_path, map_location=dev_id)
        # for item in ["model"]:
        #     config[item] = pretrained["config"][item]

    # 1: Load datasets
    train_dataloader, val_dataloader = \
        get_data(config['dataset'], config['seed'])

    # 2: Define network
    set_seed(config['seed'])
    model = get_instance(config['model']).to(device)

    # if config['parallel']:
    #     print("Load parallel model")
    #     model = nn.DataParallel(model)

    # Train from pretrained if it is not None
    if pretrained is not None:
        pretrained = torch.load(pretrained_path)
        if 'model_state_dict' in pretrained:
            model.load_state_dict(pretrained['model_state_dict'])
        else:
            print("Load model case 2")
            try:
                ret = model.load_state_dict(pretrained, strict=False)
            except RuntimeError as e:
                print(f'[Warning] Ignoring {e}')
                print(
                    '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
                )

    # 3: Define loss
    set_seed(config['seed'])
    criterion = get_instance(config['loss']).to(device)
    criterion.device = device

    # 4: Define Optimizer
    set_seed(config['seed'])
    optimizer = get_instance(config['optimizer'], params=model.parameters())
    # 5: Define Scheduler
    set_seed(config['seed'])
    scheduler = get_instance(config['scheduler'], optimizer=optimizer)

    # 6: Define metrics
    set_seed(config['seed'])
    metric = {mcfg['name']: get_instance(mcfg) for mcfg in config['metric']}

    # 7: Create trainer
    set_seed(config['seed'])
    trainer = Trainer(device=device,
                      config=config,
                      model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      metric=metric)

    # 8: Start to train
    set_seed(config['seed'])
    trainer.train(train_dataloader=train_dataloader,
                  val_dataloader=val_dataloader)
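
A typical entry point for these train(config) functions, assuming the config lives in a YAML file (the path here is illustrative):

import yaml

if __name__ == '__main__':
    with open('configs/train.yaml') as f:
        config = yaml.safe_load(f)
    train(config)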