Example #1
def get_ema(model, hps):
    mu = hps.mu or (1. - (hps.bs * hps.ngpus / 8.) / 1000)
    ema = None
    if hps.ema and hps.train:
        if hps.cpu_ema:
            if dist.get_rank() == 0:
                print("Using CPU EMA")
            ema = CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq)
        elif hps.ema_fused:
            ema = FusedEMA(model.parameters(), mu=mu)
        else:
            ema = EMA(model.parameters(), mu=mu)
    return ema
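
The EMA, CPUEMA and FusedEMA classes referenced above are not shown in this snippet. As a rough sketch of what a parameter-level EMA with decay factor mu typically looks like (class and method names here are illustrative, not the actual implementation):

import torch

class SimpleEMA:
    """Keeps a shadow copy of each parameter and updates it as
    shadow = mu * shadow + (1 - mu) * param after every optimizer step."""
    def __init__(self, parameters, mu=0.999):
        self.mu = mu
        self.params = [p for p in parameters if p.requires_grad]
        self.shadow = [p.detach().clone() for p in self.params]

    @torch.no_grad()
    def step(self):
        for s, p in zip(self.shadow, self.params):
            s.mul_(self.mu).add_(p, alpha=1.0 - self.mu)

A typical call site would construct it once (ema = SimpleEMA(model.parameters(), mu=mu)) and call ema.step() after each optimizer update.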
Example #2
def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Get model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p, p['scan_model'])
    print(model)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # Get criterion
    print(colored('Get loss', 'blue'))
    criterion = get_criterion(p)
    criterion.cuda()
    print(criterion)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue')) 
    torch.backends.cudnn.benchmark = True

    # Optimizer
    print(colored('Retrieve optimizer', 'blue'))
    optimizer = get_optimizer(p, model)
    print(optimizer)

    # Dataset
    print(colored('Retrieve dataset', 'blue'))
    
    # Transforms 
    strong_transforms = get_train_transformations(p)
    val_transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, {'standard': val_transforms, 'augment': strong_transforms},
                                        split='train', to_augmented_dataset=True) 
    train_dataloader = get_train_dataloader(p, train_dataset)
    val_dataset = get_val_dataset(p, val_transforms) 
    val_dataloader = get_val_dataloader(p, val_dataset)
    print(colored('Train samples %d - Val samples %d' %(len(train_dataset), len(val_dataset)), 'yellow'))

    # Checkpoint
    if os.path.exists(p['selflabel_checkpoint']):
        print(colored('Restart from checkpoint {}'.format(p['selflabel_checkpoint']), 'blue'))
        checkpoint = torch.load(p['selflabel_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])        
        start_epoch = checkpoint['epoch']

    else:
        print(colored('No checkpoint file at {}'.format(p['selflabel_checkpoint']), 'blue'))
        start_epoch = 0

    # EMA
    if p['use_ema']:
        ema = EMA(model, alpha=p['ema_alpha'])
    else:
        ema = None

    # Main loop
    print(colored('Starting main loop', 'blue'))
    
    for epoch in range(start_epoch, p['epochs']):
        print(colored('Epoch %d/%d' %(epoch+1, p['epochs']), 'yellow'))
        print(colored('-'*10, 'yellow'))

        # Adjust lr
        lr = adjust_learning_rate(p, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # Perform self-labeling 
        print('Train ...')
        selflabel_train(train_dataloader, model, criterion, optimizer, epoch, ema=ema)

        # Evaluate (To monitor progress - Not for validation)
        print('Evaluate ...')
        predictions = get_predictions(p, val_dataloader, model)
        clustering_stats = hungarian_evaluate(0, predictions, compute_confusion_matrix=False) 
        print(clustering_stats)
        
        # Checkpoint
        print('Checkpoint ...')
        torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(), 
                    'epoch': epoch + 1}, p['selflabel_checkpoint'])
        #torch.save(model.module.state_dict(), p['selflabel_model'])
        torch.save(model.module.state_dict(), os.path.join(p['selflabel_dir'], 'model_%d.pth.tar' %(epoch)))
    
    # Evaluate and save the final model
    print(colored('Evaluate model at the end', 'blue'))
    predictions = get_predictions(p, val_dataloader, model)
    clustering_stats = hungarian_evaluate(0, predictions, 
                                class_names=val_dataset.classes,
                                compute_confusion_matrix=True,
                                confusion_matrix_file=os.path.join(p['selflabel_dir'], 'confusion_matrix.png'))
    print(clustering_stats)
    torch.save(model.module.state_dict(), p['selflabel_model'])
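
hungarian_evaluate matches predicted cluster ids to ground-truth classes before computing accuracy. A minimal sketch of that matching step, assuming scipy is available (function name and signature are illustrative, not the actual helper):

import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_accuracy(y_pred, y_true, num_classes):
    # Count how often cluster i co-occurs with class j, then solve the
    # assignment problem for the best one-to-one cluster-to-class mapping.
    cost = np.zeros((num_classes, num_classes), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        cost[p, t] += 1
    row_ind, col_ind = linear_sum_assignment(cost, maximize=True)
    mapping = dict(zip(row_ind, col_ind))
    remapped = np.array([mapping[p] for p in y_pred])
    return float((remapped == np.asarray(y_true)).mean())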
Example #3
def main_worker(gpu, ngpus_per_node, args, num_classes, labeled_data,
                labeled_targets, transform_labeled, unlabeled_data,
                unlabeled_targets, transform_unlabeled, val_data, val_targets,
                transform_val, test_data, test_targets, transform_test):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.rank)

    model = init_model(args, num_classes)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.train_batch = int(args.train_batch / ngpus_per_node)
            args.n_imgs_per_epoch = int(args.n_imgs_per_epoch / ngpus_per_node)
            args.test_batch = int(args.test_batch / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                model)  # sync BN layers
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            warnings.warn(
                'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        warnings.warn(
            'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        warnings.warn(
            'NOT DISTRIBUTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    per_epoch_steps = args.n_imgs_per_epoch // args.train_batch
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name or 'bias' in name:
            non_wd_params.append(
                param)  # bn.weight, bn.bias and classifier.bias, conv2d.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{
        'params': wd_params,
        'weight_decay': args.weight_decay
    }, {
        'params': non_wd_params,
        'weight_decay': 0
    }]
    optimizer = optim.SGD(param_list,
                          lr=args.lr,
                          momentum=args.momentum,
                          nesterov=args.nesterov)
    total_steps = args.epochs * per_epoch_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, total_steps)
    cudnn.benchmark = True

    labeledset = CIFAR_Semi(labeled_data,
                            labeled_targets,
                            transform=transform_labeled)
    unlabeledset = CIFAR_Semi(unlabeled_data,
                              unlabeled_targets,
                              transform=transform_unlabeled)
    valset = CIFAR_Semi(val_data, val_targets, transform=transform_val)
    testset = CIFAR_Semi(test_data, test_targets, transform=transform_test)

    if args.distributed:
        labeled_sampler = DistributedSampler(labeledset,
                                             num_samples=per_epoch_steps *
                                             args.train_batch * ngpus_per_node)
        unlabeled_sampler = DistributedSampler(
            unlabeledset,
            num_samples=per_epoch_steps * args.train_batch * ngpus_per_node *
            args.mu)
    else:
        labeled_sampler = None
        unlabeled_sampler = None
    labeledloader = DataLoader(labeledset,
                               batch_size=args.train_batch,
                               shuffle=(labeled_sampler is None),
                               num_workers=args.workers,
                               pin_memory=True,
                               sampler=labeled_sampler)
    unlabeledloader = DataLoader(unlabeledset,
                                 batch_size=args.train_batch * args.mu,
                                 shuffle=(unlabeled_sampler is None),
                                 num_workers=args.workers,
                                 pin_memory=True,
                                 sampler=unlabeled_sampler)

    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=args.test_batch,
                                            shuffle=False,
                                            num_workers=args.workers,
                                            drop_last=False)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             drop_last=False)

    best_val_acc = 0
    best_test_acc = 0
    start_epoch = 0
    if args.use_ema:  # every process keeps an EMA model, but only rank 0 saves it
        ema_model = EMA(model, args.ema_decay)

    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_test_acc = checkpoint['best_test_acc']
        best_val_acc = checkpoint['best_val_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        if args.use_ema:
            ema_model.load_state_dict(checkpoint['ema_state_dict'])
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
    else:
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title)
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Train Loss X', 'Train Loss U',
                'Mask', 'Total Acc', 'Used Acc', 'Valid Loss', 'Valid Acc.',
                'Test Loss', 'Test Acc.'
            ])

    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            labeled_sampler.set_epoch(epoch)
            unlabeled_sampler.set_epoch(epoch)
        lr = optimizer.param_groups[0]['lr']
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))
        loss, loss_x, loss_u, mask_prob, total_c, used_c = train(
            labeledloader, unlabeledloader, model,
            ema_model if args.use_ema else None, optimizer, scheduler, epoch,
            args)

        if args.use_ema:
            ema_model.apply_shadow()
        val_loss, val_acc = test(valloader, model, epoch, args)
        test_loss, test_acc = test(testloader, model, epoch, args)
        if args.use_ema:
            ema_model.restore()

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            logger.append([
                lr, loss, loss_x, loss_u, mask_prob, total_c, used_c, val_loss,
                val_acc, test_loss, test_acc
            ])
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'ema_state_dict':
                    ema_model.shadow if args.use_ema else None,
                    'best_val_acc': best_val_acc,
                    'best_test_acc': best_test_acc,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                },
                checkpoint=args.checkpoint,
                filename='checkpoint.pth.tar')
        best_val_acc = max(val_acc, best_val_acc)
        best_test_acc = max(test_acc, best_test_acc)
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        logger.close()
    print('Best test acc:', best_test_acc)
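
The EMA class used in this training loop is assumed to expose apply_shadow()/restore() for swapping the averaged weights in and out around evaluation, plus a shadow dict that is checkpointed and reloaded via load_state_dict(). A minimal sketch consistent with that interface (illustrative, not the repository's implementation):

import torch

class EMA:
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        # shadow holds the exponential moving average of the model state
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.backup = {}

    @torch.no_grad()
    def update(self):
        for k, v in self.model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k] = self.decay * self.shadow[k] + (1.0 - self.decay) * v
            else:  # integer buffers such as num_batches_tracked are copied as-is
                self.shadow[k] = v.detach().clone()

    def apply_shadow(self):
        # swap the EMA weights in for evaluation, remembering the raw weights
        self.backup = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
        self.model.load_state_dict(self.shadow)

    def restore(self):
        # put the raw training weights back after evaluation
        self.model.load_state_dict(self.backup)
        self.backup = {}

    def load_state_dict(self, shadow):
        self.shadow = {k: v.clone() for k, v in shadow.items()}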
Example #5
def train_model_from_scratch(model, logger_name, checkpoint_name, use_cuda,
                             labeled_data, labeled_targets, tfms_labeled,
                             unlabeled_data, unlabeled_targets, tfms_unlabeled,
                             val_data, val_targets, tfms_val, test_data,
                             test_targets, tfms_test):
    """
    This function trains a pre-defined model from scratch, then tests it and logs the results.
    The training scheme is defined in args.
    """
    # define the trainloader, valloader, testloader
    labeledset = CIFAR_Semi(labeled_data,
                            labeled_targets,
                            transform=tfms_labeled)
    unlabeledset = CIFAR_Semi(unlabeled_data,
                              unlabeled_targets,
                              transform=tfms_unlabeled)
    valset = CIFAR_Semi(val_data, val_targets, transform=tfms_val)
    testset = CIFAR_Semi(test_data, test_targets, transform=tfms_test)

    per_epoch_steps = args.n_imgs_per_epoch // args.train_batch

    sampler_x = RandomSampler(labeledset,
                              replacement=True,
                              num_samples=per_epoch_steps * args.train_batch)
    batch_sampler_x = BatchSampler(sampler_x,
                                   batch_size=args.train_batch,
                                   drop_last=True)
    labeledloader = DataLoader(labeledset,
                               batch_sampler=batch_sampler_x,
                               num_workers=args.workers)

    sampler_u = RandomSampler(unlabeledset,
                              replacement=True,
                              num_samples=per_epoch_steps * args.train_batch *
                              args.mu)
    batch_sampler_u = BatchSampler(sampler_u,
                                   batch_size=args.train_batch * args.mu,
                                   drop_last=True)
    unlabeledloader = DataLoader(unlabeledset,
                                 batch_sampler=batch_sampler_u,
                                 num_workers=args.workers)

    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=args.test_batch,
                                            shuffle=False,
                                            num_workers=args.workers,
                                            drop_last=False)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             drop_last=False)

    # define optimizer and learning rate scheduler
    model = torch.nn.DataParallel(model).cuda()
    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name or 'bias' in name:
            non_wd_params.append(
                param)  # bn.weight, bn.bias and classifier.bias, conv2d.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{
        'params': wd_params,
        'weight_decay': args.weight_decay
    }, {
        'params': non_wd_params,
        'weight_decay': 0
    }]
    optimizer = optim.SGD(param_list,
                          lr=args.lr,
                          momentum=args.momentum,
                          nesterov=args.nesterov)
    total_steps = args.epochs * per_epoch_steps
    warmup_steps = args.warmup * per_epoch_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps,
                                                total_steps)

    # train the model from scratch
    best_val_acc = 0
    best_test_acc = 0
    start_epoch = 0
    if args.use_ema:
        ema_model = EMA(model, args.ema_decay)
    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_test_acc = checkpoint['best_test_acc']
        best_val_acc = checkpoint['best_val_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        if args.use_ema:
            ema_model.load_state_dict(checkpoint['ema_state_dict'])
        logger = Logger(os.path.join(args.checkpoint, logger_name),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, logger_name),
                        title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Loss X', 'Train Loss U',
            'Mask', 'Total Acc', 'Used Acc', 'Valid Loss', 'Valid Acc.',
            'Test Loss', 'Test Acc.'
        ])

    for epoch in range(start_epoch, args.epochs):
        lr = optimizer.param_groups[0]['lr']
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))
        loss, loss_x, loss_u, mask_prob, total_c, used_c = train(
            labeledloader, unlabeledloader, model,
            ema_model if args.use_ema else None, optimizer, scheduler, epoch,
            use_cuda)

        if args.use_ema:
            ema_model.apply_shadow()
        val_loss, val_acc = test(valloader, model, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, epoch, use_cuda)
        if args.use_ema:
            ema_model.restore()

        logger.append([
            lr, loss, loss_x, loss_u, mask_prob, total_c, used_c, val_loss,
            val_acc, test_loss, test_acc
        ])

        is_best = val_acc > best_val_acc
        if is_best:
            best_test_acc = test_acc
            best_model = copy.deepcopy(model)
        if checkpoint_name is not None:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'ema_state_dict':
                    ema_model.shadow if args.use_ema else None,
                    'best_val_acc': best_val_acc,
                    'best_test_acc': best_test_acc,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                },
                checkpoint=args.checkpoint,
                filename=checkpoint_name)
        best_val_acc = max(val_acc, best_val_acc)
    logger.close()
    print('Best test acc:', best_test_acc)
    return best_model, best_test_acc
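
get_cosine_schedule_with_warmup is not defined in these snippets. A common FixMatch-style definition built on torch.optim.lr_scheduler.LambdaLR (a sketch under that assumption; the exact cosine shape may differ in the actual repository):

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                                    num_cycles=7.0 / 16.0, last_epoch=-1):
    # Linear warmup for num_warmup_steps, then cosine decay over the remaining steps.
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, math.cos(math.pi * num_cycles * progress))
    return LambdaLR(optimizer, lr_lambda, last_epoch)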