Example #1
# Imports inferred from usage; MODELS, MEAN, STD, run and utils are assumed to
# be defined elsewhere in this module, and YFOptimizer to come from the
# YellowFin PyTorch implementation.
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms


def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs, batch_size,
          learning_rate, lr_factor, momentum, optimizer, augmentation,
          device_ids, num_workers, weight_decay, validation, evaluate, shuffle,
          half, arch):
    timestamp = "{:.0f}".format(datetime.utcnow().timestamp())
    local_timestamp = str(datetime.now())
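    # locals() here contains every argument plus the two timestamps above,
    # so `config` is a complete record of the run's configuration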
    config = {k: v for k, v in locals().items()}

    use_cuda = cuda and torch.cuda.is_available()

    # create model
    model = MODELS[arch]()

    # create optimizer
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optimizer == 'yellowfin':
        optimizer = YFOptimizer(model.parameters(),
                                lr=learning_rate,
                                mu=momentum,
                                weight_decay=weight_decay)

    else:
        raise NotImplementedError("Unknown optimizer: {}".format(optimizer))

    if restore is not None:
        if restore == 'latest':
            restore = utils.latest_file(arch)
        print(f'Restoring model from {restore}')
        assert os.path.exists(restore)
        restored_state = torch.load(restore)
        assert restored_state['arch'] == arch

        model.load_state_dict(restored_state['model'])
        if 'optimizer' in restored_state:
            optimizer.load_state_dict(restored_state['optimizer'])
        if not isinstance(optimizer, YFOptimizer):
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

        best_accuracy = restored_state['accuracy']
        start_epoch = restored_state['epoch'] + 1
        run_dir = os.path.split(restore)[0]
    else:
        best_accuracy = 0.0
        start_epoch = 1
        run_dir = f"./run/{arch}/{timestamp}"

    print('Starting accuracy is {}'.format(best_accuracy))

    if not os.path.exists(run_dir) and run_dir != '':
        os.makedirs(run_dir)
    utils.save_config(config, run_dir)

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))
    print(f"Run directory set to {run_dir}")

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    if tracking:
        train_results_file = os.path.join(run_dir, 'train_results.csv')
        valid_results_file = os.path.join(run_dir, 'valid_results.csv')
        test_results_file = os.path.join(run_dir, 'test_results.csv')
    else:
        train_results_file = None
        valid_results_file = None
        test_results_file = None

    # create loss
    criterion = nn.CrossEntropyLoss()

    if use_cuda:
        print('Copying model to GPU')
        model = model.cuda()
        criterion = criterion.cuda()

        if half:
            model = model.half()
            criterion = criterion.half()
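        # default to all visible GPUs, with one data-loading worker per GPU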
        device_ids = device_ids or list(range(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        num_workers = num_workers or len(device_ids)
    else:
        num_workers = num_workers or 1
        if half:
            print('Half precision (16-bit floating point) only works on GPU')
    print(f"using {num_workers} workers for data loading")

    # load data
    print("Preparing data:")
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_dir,
                         train=False,
                         download=True,
                         transform=transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_cuda)

    if evaluate:
        print("Only running evaluation of model on test dataset")
        run(start_epoch - 1,
            model,
            test_loader,
            use_cuda=use_cuda,
            tracking=test_results_file,
            train=False,
            half=half)
        return

    if augmentation:
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
    else:
        transform_train = []

    transform_train = transforms.Compose(transform_train + [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    train_dataset = datasets.CIFAR10(root=dataset_dir,
                                     train=True,
                                     download=True,
                                     transform=transform_train)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    assert 0 <= validation < 1, "Validation must be in [0, 1)"
    split = num_train - int(validation * num_train)
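    # after the optional shuffle below, the first `split` indices are used for
    # training and the held-out tail for validation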

    if shuffle:
        np.random.shuffle(indices)

    train_indices = indices[:split]
    valid_indices = indices[split:]

    print('Using {} examples for training'.format(len(train_indices)))
    print('Using {} examples for validation'.format(len(valid_indices)))

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(valid_indices)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=train_sampler,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               pin_memory=use_cuda)
    if validation != 0:
        valid_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=valid_sampler,
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   pin_memory=use_cuda)
    else:
        print('Using test dataset for validation')
        valid_loader = test_loader

    end_epoch = start_epoch + epochs
    # YellowFin doesn't expose param_groups, so accessing it would raise AttributeError
    if not isinstance(optimizer, YFOptimizer):
        for group in optimizer.param_groups:
            if 'lr' in group:
                print('Learning rate set to {}'.format(group['lr']))
                assert group['lr'] == learning_rate
    else:
        print(f"set lr_factor to {lr_factor}")
        optimizer.set_lr_factor(lr_factor)
    for epoch in range(start_epoch, end_epoch):
        run(epoch,
            model,
            train_loader,
            criterion,
            optimizer,
            use_cuda=use_cuda,
            tracking=train_results_file,
            train=True,
            half=half)

        valid_acc = run(epoch,
                        model,
                        valid_loader,
                        use_cuda=use_cuda,
                        tracking=valid_results_file,
                        train=False,
                        half=half)

        is_best = valid_acc > best_accuracy
        last_epoch = epoch == (end_epoch - 1)
        if is_best or checkpoint == 'all' or (checkpoint == 'last'
                                              and last_epoch):
            state = {
                'epoch': epoch,
                'arch': arch,
                'model': (model.module if use_cuda else model).state_dict(),
                'accuracy': valid_acc,
                'optimizer': optimizer.state_dict()
            }
        if is_best:
            print('New best model!')
            filename = os.path.join(run_dir, 'checkpoint_best_model.t7')
            print(f'Saving checkpoint to {filename}')
            best_accuracy = valid_acc
            torch.save(state, filename)
        if checkpoint == 'all' or (checkpoint == 'last' and last_epoch):
            filename = os.path.join(run_dir, f'checkpoint_{epoch}.t7')
            print(f'Saving checkpoint to {filename}')
            torch.save(state, filename)
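
For reference, a call with typical CIFAR-10 settings might look like the sketch
below; the argument values and the 'resnet20' key into MODELS are illustrative
assumptions, not taken from the snippet.

train(dataset_dir='./data', checkpoint='last', restore=None, tracking=True,
      cuda=True, epochs=160, batch_size=128, learning_rate=0.1, lr_factor=1.0,
      momentum=0.9, optimizer='sgd', augmentation=True, device_ids=None,
      num_workers=None, weight_decay=1e-4, validation=0.1, evaluate=False,
      shuffle=True, half=False, arch='resnet20')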
Example #2
# Imports inferred from usage; utils, validate, train_one_epoch and
# adjust_learning_rate are assumed to be defined elsewhere in this module.
import os
from datetime import datetime

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms


def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs, batch_size,
          learning_rate, learning_rate_decay, learning_rate_freq, momentum,
          optimizer, augmentation, pretrained, evaluate, num_workers,
          weight_decay, arch):
    timestamp = "{:.0f}".format(datetime.utcnow().timestamp())
    config = {k: v for k, v in locals().items()}

    use_cuda = cuda and torch.cuda.is_available()

    # create model
    if pretrained:
        print("=> using pre-trained model '{}'".format(arch))
        model = models.__dict__[arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(arch))
        model = models.__dict__[arch]()

    if optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    learning_rate,
                                    momentum=momentum,
                                    weight_decay=weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(optimizer))

    # optionally resume from a checkpoint
    if restore is not None:
        if restore == 'latest':
            restore = utils.latest_file(arch)
        print(f'=> restoring model from {restore}')
        restored_state = torch.load(restore)
        start_epoch = restored_state['epoch'] + 1
        best_prec1 = restored_state['prec1']
        model.load_state_dict(restored_state['state_dict'])
        optimizer.load_state_dict(restored_state['optimizer'])
        print('=> starting accuracy is {} (epoch {})'.format(
            best_prec1, start_epoch))
        run_dir = os.path.split(restore)[0]
    else:
        best_prec1 = 0.0
        start_epoch = 1
        run_dir = f"./run/{arch}/{timestamp}"

    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    utils.save_config(config, run_dir)

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))
    print(f"Run directory set to {run_dir}")

    # save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    if tracking:
        train_results_file = os.path.join(run_dir, 'train_results.csv')
        test_results_file = os.path.join(run_dir, 'test_results.csv')
    else:
        train_results_file = None
        test_results_file = None

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    # move model and criterion to GPU
    if use_cuda:
        model.cuda()
        criterion = criterion.cuda()
        model = torch.nn.parallel.DataParallel(model)
        num_workers = num_workers or torch.cuda.device_count()
    else:
        num_workers = num_workers or 1
    print(f"=> using {num_workers} workers for data loading")

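    # let cuDNN auto-tune convolution algorithms for the fixed 224x224 input size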
    cudnn.benchmark = True

    # Data loading code
    print("=> preparing data:")
    traindir = os.path.join(dataset_dir, 'train')
    valdir = os.path.join(dataset_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_sampler = None
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),  # formerly RandomSizedCrop, renamed in torchvision
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),  # formerly Scale, renamed in torchvision
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)

    if evaluate:
        validate(val_loader, model, criterion)
        return

    end_epoch = start_epoch + epochs
    for epoch in range(start_epoch, end_epoch):
        print('Epoch {} of {}'.format(epoch, end_epoch - 1))
        adjust_learning_rate(optimizer,
                             epoch,
                             learning_rate,
                             decay=learning_rate_decay,
                             freq=learning_rate_freq)

        # train for one epoch
        _ = train_one_epoch(train_loader,
                            model,
                            criterion,
                            optimizer,
                            epoch,
                            tracking=train_results_file)

        # evaluate on validation set
        prec1, _ = validate(val_loader,
                            model,
                            criterion,
                            epoch,
                            tracking=test_results_file)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        last_epoch = epoch == (end_epoch - 1)
        if is_best or checkpoint == 'all' or (checkpoint == 'last'
                                              and last_epoch):
            state = {
                'epoch': epoch,
                'arch': arch,
                'state_dict':
                (model.module if use_cuda else model).state_dict(),
                'prec1': prec1,
                'optimizer': optimizer.state_dict(),
            }
            if is_best:
                print('New best model!')
                filename = os.path.join(run_dir, 'checkpoint_best_model.t7')
                print(f'=> saving checkpoint to {filename}')
                torch.save(state, filename)
                best_prec1 = prec1
            if checkpoint == 'all' or (checkpoint == 'last' and last_epoch):
                filename = os.path.join(run_dir, f'checkpoint_{epoch}.t7')
                print(f'=> saving checkpoint to {filename}')
                torch.save(state, filename)
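
The helpers train_one_epoch, validate and adjust_learning_rate are defined
elsewhere in the module. Given the decay and freq arguments at the call site,
adjust_learning_rate is presumably a step-decay schedule; a minimal sketch
under that assumption (the exact formula and off-by-one handling are guesses):

def adjust_learning_rate(optimizer, epoch, initial_lr, decay=0.1, freq=30):
    # hypothetical step decay: scale the base lr by `decay` once every `freq` epochs
    lr = initial_lr * (decay ** ((epoch - 1) // freq))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr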