def main():
    args = parser.parse_args()
    logger = get_logger(args.logging_file)
    logger.info(args)
    args.save_dir = os.path.join(os.getcwd(), args.save_dir)
    check_dir(args.save_dir)

    assert args.world_size >= 1

    args.classes = 1000
    args.num_training_samples = 1281167
    args.world = args.rank
    ngpus_per_node = torch.cuda.device_count()
    # Total number of distributed processes = GPUs per node * number of nodes.
    args.world_size = ngpus_per_node * args.world_size
    # Mixed-precision (AMP) training is enabled when dtype is float16.
    args.mix_precision_training = args.dtype == 'float16'
    # Launch one training process per local GPU on this node.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
Example 2
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import logging

from scripts import linkageList
from scripts import reclusterTree
from scripts.utils import get_logger

logger = get_logger(level=logging.INFO)


def heat_dendrogram(
    truthJet=None,
    recluster_jet1=None,
    recluster_jet2=None,
    full_path=False,
    FigName=None,
):
    """
	Create  a heat dendrogram clustermap.

	Args:
	:param truthJet: Truth jet dictionary
	:param recluster_jet1: reclustered jet 1
	:param recluster_jet2: reclustered jet 2
	:param full_path: Bool. If True, then use the total number of steps to connect a pair of leaves as the heat data. If False, then Given a pair of jet constituents {i,j} and the number of steps needed for each constituent to reach their closest common ancestor {Si,Sj}, the heat map scale represents the maximum number of steps, i.e. max{Si,Sj}.
	:param FigName: Dir and location to save a plot.
	"""

    # Build truth jet heat data
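

# --- Illustration (not part of the original example) -------------------------
# A minimal sketch of the heat metric described in the heat_dendrogram
# docstring, assuming a generic tree given as an array of parent indices
# (root marked with parent -1). The names `parents`, `leaf_i` and `leaf_j`
# are hypothetical; the truth-jet dictionaries used above have their own format.
def ancestor_steps_heat(parents, leaf_i, leaf_j, full_path=False):
    """Heat value for a pair of leaves.

    full_path=True : total number of steps on the path leaf_i -> leaf_j.
    full_path=False: max{Si, Sj}, where Si and Sj count the steps from each
                     leaf up to their closest common ancestor.
    """
    def chain_to_root(node):
        # Nodes visited walking from `node` up to the root, inclusive.
        chain = [node]
        while parents[chain[-1]] != -1:
            chain.append(parents[chain[-1]])
        return chain

    path_i = chain_to_root(leaf_i)
    path_j = chain_to_root(leaf_j)
    common = set(path_i) & set(path_j)
    # The closest common ancestor is the first shared node on each upward path;
    # its index along a path equals the number of steps from that leaf.
    s_i = min(k for k, node in enumerate(path_i) if node in common)
    s_j = min(k for k, node in enumerate(path_j) if node in common)
    return s_i + s_j if full_path else max(s_i, s_j)


# Per-GPU worker: main() spawns one copy of this function per local GPU.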
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    logger = get_logger(args.logging_file)
    logger.info("Use GPU: {} for training".format(args.gpu))

    # Global rank of this process = node rank * GPUs per node + local GPU index.
    args.rank = args.rank * ngpus_per_node + gpu
    torch.distributed.init_process_group(backend="nccl",
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)

    epochs = args.epochs
    input_size = args.input_size
    resume_epoch = args.resume_epoch
    initializer = KaimingInitializer()
    zero_gamma = ZeroLastGamma()
    mix_precision_training = args.mix_precision_training
    # Only the first process on each node logs summary info and saves checkpoints.
    is_first_rank = args.rank % ngpus_per_node == 0

    batches_per_epoch = args.num_training_samples // (args.batch_size *
                                                      ngpus_per_node)
    # Linear scaling rule: lr 0.1 per 32 images in the per-node batch, unless
    # an explicit learning rate was given.
    lr = 0.1 * (args.batch_size * ngpus_per_node //
                32) if args.lr == 0 else args.lr

    model = get_model(models, args.model)

    model.apply(initializer)
    if args.last_gamma:
        model.apply(zero_gamma)
        logger.info('Apply zero last gamma init.')

    if is_first_rank and args.model_info:
        summary(model, torch.rand((1, 3, input_size, input_size)))

    # Optionally exclude bias and normalization parameters from weight decay.
    parameters = model.parameters() if not args.no_wd else no_decay_bias(model)
    if args.sgd_gc:
        logger.info('Use SGD_GC optimizer.')
        optimizer = SGD_GC(parameters,
                           lr=lr,
                           momentum=args.momentum,
                           weight_decay=args.wd,
                           nesterov=True)
    else:
        optimizer = optim.SGD(parameters,
                              lr=lr,
                              momentum=args.momentum,
                              weight_decay=args.wd,
                              nesterov=True)

    # Cosine decay with warmup. base_lr must be the (possibly auto-scaled) lr
    # passed to the optimizer; args.lr may be 0 when auto-scaling is used.
    lr_scheduler = CosineWarmupLr(optimizer,
                                  batches_per_epoch,
                                  epochs,
                                  base_lr=lr,
                                  warmup_epochs=args.warmup_epochs)

    # dropblock_scheduler = DropBlockScheduler(model, batches_per_epoch, epochs)

    if args.lookahead:
        optimizer = Lookahead(optimizer)
        logger.info('Use lookahead optimizer.')

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    # Split the configured dataloader workers among the GPUs on this node.
    args.num_workers = int(
        (args.num_workers + ngpus_per_node - 1) / ngpus_per_node)

    if args.mix_precision_training and is_first_rank:
        logger.info('Train with FP16.')

    scaler = GradScaler(enabled=args.mix_precision_training)
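    # Wrap the model for synchronized data-parallel training on this GPU.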
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # Label smoothing (0.1) when requested, plain cross entropy otherwise.
    Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \
        LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.autoaugment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),  # the policy must be instantiated to act as a transform
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            # Cutout(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.ToTensor(),
            normalize,
        ])

    val_transform = transforms.Compose([
        # Resize so the subsequent center crop keeps the standard 0.875 ratio.
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = ImageNet(args.data_path,
                         split='train',
                         transform=train_transform)
    val_set = ImageNet(args.data_path, split='val', transform=val_transform)

    # Each process loads a distinct shard of the training set; shuffling is
    # handled by the DistributedSampler rather than the DataLoader.
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=False,
                              pin_memory=True,
                              num_workers=args.num_workers,
                              drop_last=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers,
                            drop_last=False)

    if resume_epoch > 0:
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume_param, map_location=loc)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        logger.info("Finished loading resume params.")

    # Fixed input size, so let cuDNN benchmark and cache the fastest kernels.
    torch.backends.cudnn.benchmark = True

    top1_acc = metric.Accuracy(name='Top1 Accuracy')
    top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy')
    loss_record = metric.NumericalCost(name='Loss')

    for epoch in range(resume_epoch, epochs):
        tic = time.time()
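        # Re-seed the DistributedSampler so shuffling differs between epochs.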
        train_sampler.set_epoch(epoch)
        if not args.mixup:
            train_one_epoch(model, train_loader, Loss, optimizer, epoch,
                            lr_scheduler, logger, top1_acc, loss_record,
                            scaler, args)
        else:
            train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch,
                                  lr_scheduler, logger, loss_record, scaler,
                                  args)
        train_speed = int(args.num_training_samples // (time.time() - tic))
        if is_first_rank:
            logger.info(
                'Finish one epoch speed: {} samples/s'.format(train_speed))
        test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc,
             loss_record, args)

        # Only the first process on each node writes checkpoints.
        if is_first_rank:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }
            torch.save(
                checkpoint, '{}/{}_{}_{:.5}.pt'.format(args.save_dir,
                                                       args.model, epoch,
                                                       top1_acc.get()))