Beispiel #1
0
def load_model(hparams):
    """Build a Tacotron2 on the GPU and wrap it for the configured parallelism.

    When ``hparams.fp16_run`` is set, the network is cast with ``.half()`` and
    then passed through ``batchnorm_to_float`` (presumably restoring batch-norm
    layers to fp32 -- confirm against its definition).
    ``hparams.distributed_run`` selects ``DistributedDataParallel``; otherwise
    ``DataParallel`` is used when more than one CUDA device is visible, and the
    bare model is returned on a single device.
    """
    net = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        net = batchnorm_to_float(net.half())

    if hparams.distributed_run:
        return DistributedDataParallel(net)
    if torch.cuda.device_count() > 1:
        return DataParallel(net)
    return net
Beispiel #2
0
def load_model(hparams):
    """Construct a CUDA Tacotron2 and wrap it for the available parallelism.

    In fp16 mode the model is cast to half precision (via ``batchnorm_to_float``
    on the halved network) and the decoder attention layer's
    ``score_mask_value`` is set to the most negative finite float16 value,
    presumably so masked scores stay representable in half precision.
    ``hparams.distributed_run`` selects ``DistributedDataParallel``; otherwise
    ``DataParallel`` is used when more than one CUDA device is visible.
    """
    net = Tacotron2(hparams).cuda()

    if hparams.fp16_run:
        net = batchnorm_to_float(net.half())
        fp16_floor = float(finfo('float16').min)
        net.decoder.attention_layer.score_mask_value = fp16_floor

    if hparams.distributed_run:
        return DistributedDataParallel(net)
    if torch.cuda.device_count() > 1:
        return DataParallel(net)
    return net
Beispiel #3
0
def main():
    """Train or evaluate an ImageNet classifier with optional fp16 and DDP.

    Configuration comes from the module-level ``args`` namespace. In fp16 mode
    the optimizer works on fp32 "master" copies of the half-precision weights,
    kept in the module-level ``param_copy``. Starting at epoch
    ``args.epochs - 6`` the loaders are rebuilt with ``args.sz = 288`` and a
    batch size of 128. Only rank 0 writes checkpoints.

    Returns:
        The result of ``validate`` when ``args.evaluate`` is set, else ``None``.
    """
    print("~~epoch\thours\ttop1Accuracy\n")
    start_time = datetime.now()
    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        model = models.__dict__[args.arch](pretrained=True)
    else:
        model = models.__dict__[args.arch]()

    model = model.cuda()
    n_dev = torch.cuda.device_count()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)
    elif args.dp:
        model = nn.DataParallel(model)
        # DataParallel splits each batch across all visible GPUs; scale the
        # batch so the per-GPU batch size stays the same.
        args.batch_size *= n_dev

    global param_copy
    if args.fp16:
        # fp32 master copies of the fp16 model weights; the optimizer steps these.
        param_copy = [
            param.clone().type(torch.cuda.FloatTensor).detach()
            for param in model.parameters()
        ]
        for param in param_copy:
            param.requires_grad = True
    else:
        param_copy = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(param_copy,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    best_prec1 = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.fp16:
                # param_copy was cloned before the checkpoint was loaded and
                # still holds the initial weights; refresh it from the restored
                # model so optimization continues from the checkpoint.
                for master, param in zip(param_copy, model.parameters()):
                    master.data.copy_(param.data)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # get_loaders yields four values (as unpacked in the epoch loop below);
    # the validation sampler is needed for the distributed set_epoch call.
    train_loader, val_loader, train_sampler, val_sampler = get_loaders(
        traindir, valdir)

    if args.evaluate:
        # `epoch` is only bound by the training loop; report the start epoch.
        return validate(val_loader, model, criterion, args.start_epoch,
                        start_time)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        if epoch == args.epochs - 6:
            # Final phase: larger images and a fixed batch size.
            args.sz = 288
            args.batch_size = 128
            train_loader, val_loader, train_sampler, val_sampler = get_loaders(
                traindir, valdir, use_val_sampler=False, min_scale=0.5)

        if args.distributed:
            # Done after any loader rebuild so the sampler actually in use is
            # reseeded for this epoch.
            train_sampler.set_epoch(epoch)
            # NOTE(review): with use_val_sampler=False there may be no
            # validation sampler; guard against None.
            if val_sampler is not None:
                val_sampler.set_epoch(epoch)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            train(train_loader, model, criterion, optimizer, epoch)

        if args.prof:
            break
        prec1 = validate(val_loader, model, criterion, epoch, start_time)

        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Beispiel #4
0
def train300_mlperf_coco(args):
    """Train SSD300 on COCO 2017 with a stepped learning-rate schedule.

    Args:
        args: parsed options. Fields read here include ``world_size``,
            ``dist_backend``, ``dist_url``, ``no_cuda``, ``data`` (root holding
            ``annotations/``, ``train2017/`` and ``val2017/``), ``batch_size``,
            ``checkpoint``, ``iteration`` (starting iteration counter),
            ``epochs``, ``evaluation`` (iterations at which to save and
            evaluate), ``no_save`` and ``threshold``.

    Returns:
        ``None``. Returns early as soon as ``coco_eval`` reports a truthy
        result at one of the evaluation iterations; otherwise after
        ``args.epochs`` epochs.
    """
    args.distributed = args.world_size > 1

    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # DistributedSampler reads world size / rank from the default process
    # group, so the group must exist before the sampler below is built.
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    train_trans = SSDTransformer(dboxes, (300, 300), val=False)
    val_trans = SSDTransformer(dboxes, (300, 300), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    train_coco = COCODetection(train_coco_root, train_annotate, train_trans)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_coco)
    else:
        train_sampler = None
    # DataLoader rejects shuffle=True combined with a sampler; the
    # DistributedSampler shuffles on its own (reseeded via set_epoch below).
    train_dataloader = DataLoader(train_coco,
                                  batch_size=args.batch_size,
                                  shuffle=(train_sampler is None),
                                  num_workers=4,
                                  sampler=train_sampler)

    ssd300 = SSD300(train_coco.labelnum)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        # Map to CPU when CUDA is unused so GPU-saved checkpoints still load.
        od = torch.load(args.checkpoint,
                        map_location=None if use_cuda else 'cpu')
        # Checkpoints written below come from the DataParallel/DDP wrapper,
        # whose keys carry a "module." prefix; strip it so they load into the
        # bare model. Unprefixed checkpoints are unaffected.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in od["model"].items()
        }
        ssd300.load_state_dict(state_dict)
    ssd300.train()
    if use_cuda:
        ssd300.cuda()
    loss_func = Loss(dboxes)
    if use_cuda:
        loss_func.cuda()
    if args.distributed:
        ssd300 = DistributedDataParallel(ssd300)
    else:
        ssd300 = torch.nn.DataParallel(ssd300)

    optim = torch.optim.SGD(ssd300.parameters(),
                            lr=1e-3,
                            momentum=0.9,
                            weight_decay=5e-4)
    print("epoch", "nbatch", "loss")

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}

    for epoch in range(args.epochs):
        if train_sampler is not None:
            # Give the distributed sampler a different shuffle every epoch.
            train_sampler.set_epoch(epoch)

        for nbatch, (img, img_size, bbox,
                     label) in enumerate(train_dataloader):

            start = time.time()
            if iter_num == 160000:
                print("")
                print("lr decay step #1")
                for param_group in optim.param_groups:
                    param_group['lr'] = 1e-4

            if iter_num == 200000:
                print("")
                print("lr decay step #2")
                for param_group in optim.param_groups:
                    param_group['lr'] = 1e-5

            if use_cuda:
                img = img.cuda()
            # Gradients w.r.t. the input images are never used; don't track them.
            img = Variable(img, requires_grad=False)
            ploc, plabel = ssd300(img)
            trans_bbox = bbox.transpose(1, 2).contiguous()
            if use_cuda:
                trans_bbox = trans_bbox.cuda()
                label = label.cuda()
            gloc = Variable(trans_bbox, requires_grad=False)
            glabel = Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            loss_value = loss.item()
            # Keep infinite losses out of the exponential moving average.
            if not np.isinf(loss_value):
                avg_loss = 0.999 * avg_loss + 0.001 * loss_value

            optim.zero_grad()
            loss.backward()
            optim.step()
            end = time.time()

            if nbatch % 10 == 0:
                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, Average time: {:.3f} secs"
                      .format(iter_num, loss_value, avg_loss, end - start))

            if iter_num in args.evaluation:
                if not args.no_save:
                    print("")
                    print("saving model...")
                    # torch.save does not create missing directories.
                    os.makedirs("./models", exist_ok=True)
                    torch.save(
                        {
                            "model": ssd300.state_dict(),
                            "label_map": train_coco.label_info
                        }, "./models/iter_{}.pt".format(iter_num))

                if coco_eval(ssd300, val_coco, cocoGt, encoder, inv_map,
                             args.threshold):
                    return

            iter_num += 1
Beispiel #5
0
def main():
    """Train or evaluate an ImageNet classifier with optional fp16/DDP/DP.

    Configuration is read from the module-level ``args``. The best top-1
    precision is tracked in the module-level ``best_prec1``. In fp16 mode the
    optimizer updates fp32 master copies (module-level ``param_copy``) of the
    half-precision weights. Only rank 0 writes checkpoints.
    """
    global best_prec1, args

    args.distributed = args.world_size > 1
    if args.distributed:
        # NOTE(review): no per-process GPU is selected here (args.gpu /
        # torch.cuda.set_device are not set) -- confirm this is intended.
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    n_dev = torch.cuda.device_count()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)
    elif args.dp:
        model = nn.DataParallel(model)
        # Keep the per-GPU batch size unchanged when splitting across devices.
        args.batch_size *= n_dev

    global param_copy
    if args.fp16:
        # fp32 master copies of the fp16 model weights; the optimizer steps these.
        param_copy = [
            param.clone().type(torch.cuda.FloatTensor).detach()
            for param in model.parameters()
        ]
        for param in param_copy:
            param.requires_grad = True
    else:
        param_copy = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(param_copy,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # args.gpu is not assigned in this function; if the parser does not
            # provide it, None maps storages onto the current CUDA device.
            gpu = getattr(args, 'gpu', None)
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.fp16:
                # param_copy was cloned before the checkpoint was loaded;
                # refresh it so optimization continues from the restored weights.
                for master, param in zip(param_copy, model.parameters()):
                    master.data.copy_(param.data)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.sz),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.distributed else None)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(int(args.sz * 1.14)),
            transforms.CenterCrop(args.sz),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Beispiel #6
0
def main():
    """Train or evaluate an ImageNet model with structured layers under DDP.

    Configuration comes from the module-level ``args``. Parameters flagged with
    ``_is_structured`` are optimized with ``weight_decay=0.0``; all others use
    ``args.weight_decay``. Both groups are also published as the module-level
    lists ``structured_params`` / ``unstructured_params``. Image size is
    changed progressively: 224 from about 40% of the epochs, then 288 (batch
    size 128) from about 92%. Only rank 0 writes checkpoints.

    Returns:
        The result of ``validate`` when ``args.evaluate`` is set, else ``None``.
    """
    start_time = datetime.now()
    # Distributed mode is forced on (the original condition was
    # `args.world_size > 1`).
    args.distributed = True
    args.gpu = 0
    if args.distributed:
        import socket
        args.gpu = args.rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        logger.info('| distributed init (rank {}): {}'.format(
            args.rank, args.distributed_init_method))
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.distributed_init_method,
            world_size=args.world_size,
            rank=args.rank,
        )
        logger.info('| initialized host {} as rank {}'.format(
            socket.gethostname(), args.rank))

    # create model
    if args.pretrained:
        model = models.__dict__[args.arch](pretrained=True)
    else:
        model = models.__dict__[args.arch](
            num_structured_layers=args.num_structured_layers,
            structure_type=args.structure_type,
            nblocks=args.nblocks,
            param=args.param)
    model = model.cuda()
    n_dev = torch.cuda.device_count()
    logger.info('Created model')
    if args.distributed:
        model = DDP(model)
    elif args.dp:
        model = nn.DataParallel(model)
        args.batch_size *= n_dev
    logger.info('Set up data parallel')

    global structured_params
    global unstructured_params
    # Materialize as lists: filter() objects are single-use iterators, so the
    # optimizer constructor would exhaust them and leave these globals empty
    # for any later reader.
    structured_params = [
        p for p in model.parameters()
        if hasattr(p, '_is_structured') and p._is_structured
    ]
    unstructured_params = [
        p for p in model.parameters()
        if not (hasattr(p, '_is_structured') and p._is_structured)
    ]

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD([{
        'params': structured_params,
        'weight_decay': 0.0
    }, {
        'params': unstructured_params
    }],
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    logger.info('Created optimizer')
    best_acc1 = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.small:
        traindir = os.path.join(args.data + '-sz/160', 'train')
        valdir = os.path.join(args.data + '-sz/160', 'val')
        args.sz = 128
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        args.sz = 224

    train_loader, val_loader, train_sampler, val_sampler = get_loaders(
        traindir, valdir, use_val_sampler=True)
    logger.info('Loaded data')
    if args.evaluate:
        # `epoch` is only bound by the training loop; report the start epoch.
        return validate(val_loader, model, criterion, args.start_epoch,
                        start_time)

    logger.info(model)
    logger.info('| model {}, criterion {}'.format(
        args.arch, criterion.__class__.__name__))
    logger.info('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    for epoch in range(args.start_epoch, args.epochs):
        logger.info(f'Epoch {epoch}')
        adjust_learning_rate(optimizer, epoch)
        if epoch == int(args.epochs * 0.4 + 0.5):
            # Switch from the (optionally) small dataset to full-size images.
            traindir = os.path.join(args.data, 'train')
            valdir = os.path.join(args.data, 'val')
            args.sz = 224
            train_loader, val_loader, train_sampler, val_sampler = get_loaders(
                traindir, valdir)
        if epoch == int(args.epochs * 0.92 + 0.5):
            args.sz = 288
            args.batch_size = 128
            train_loader, val_loader, train_sampler, val_sampler = get_loaders(
                traindir, valdir, use_val_sampler=False, min_scale=0.5)

        if args.distributed:
            train_sampler.set_epoch(epoch)
            # NOTE(review): rebuilt loaders may come without a validation
            # sampler (use_val_sampler=False); guard against None.
            if val_sampler is not None:
                val_sampler.set_epoch(epoch)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            train(train_loader, model, criterion, optimizer, epoch)

        if args.prof:
            break
        acc1 = validate(val_loader, model, criterion, epoch, start_time)

        if args.rank == 0:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Beispiel #7
0
def main():
    """Train or evaluate an ImageNet classifier with optional fp16/DDP/DP.

    Reads the module-level ``args`` and updates the module-level
    ``best_prec1``. In fp16 mode the optimizer steps fp32 master copies
    (module-level ``param_copy``) of the half-precision weights. Only rank 0
    writes checkpoints.
    """
    global best_prec1, args

    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    n_dev = torch.cuda.device_count()
    if args.fp16: model = network_to_half(model)
    if args.distributed:
        model = DDP(model)
    elif args.dp:
        model = nn.DataParallel(model)
        # Keep the per-GPU batch size unchanged when splitting across devices.
        args.batch_size *= n_dev

    global param_copy
    if args.fp16:
        # fp32 master copies of the fp16 weights; the optimizer steps these.
        param_copy = [param.clone().type(torch.cuda.FloatTensor).detach() for param in model.parameters()]
        for param in param_copy: param.requires_grad = True
    else: param_copy = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(param_copy, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.fp16:
                # param_copy was cloned before loading; refresh it so training
                # resumes from the checkpoint rather than the initial weights.
                for master, param in zip(param_copy, model.parameters()):
                    master.data.copy_(param.data)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else: print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.sz),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if args.distributed else None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(int(args.sz*1.14)),
            transforms.CenterCrop(args.sz),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed: train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof: break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
Beispiel #8
0
def main():
    """Train or evaluate a torchvision ImageNet model on CPU, GPU(s) or DDP.

    Parses command-line options into the module-level ``args`` and tracks the
    best top-1 precision in the module-level ``best_prec1``. CUDA is used only
    when available and not disabled via ``args.no_cuda``.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # DataParallel (and CUDA-backed DistributedDataParallel) require the
    # module's parameters to already be on the GPU, so move it before wrapping.
    if args.cuda:
        model.cuda()

    if not args.distributed and args.cuda:
        if args.arch.startswith(('alexnet', 'vgg')):
            # Only the convolutional feature extractor is parallelized here.
            model.features = torch.nn.DataParallel(model.features)
        else:
            model = torch.nn.DataParallel(model)
    elif args.distributed:
        model = DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Without CUDA, map GPU-saved tensors onto the CPU.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        """
def main():
    """Train a scattering-transform ImageNet classifier.

    Parses command-line options into the module-level ``args`` and derives a
    per-configuration save folder (module-level ``folder_save``) from them.
    Builds the module-level ``scat`` Scattering transform and the model, saves
    an initial (epoch -1) checkpoint, optionally resumes from a checkpoint,
    then runs the train/validate/checkpoint loop. The best top-1 precision is
    tracked in the module-level ``best_prec1``. In distributed mode only rank 0
    writes checkpoints.
    """
    global args, folder_save
    args = parser.parse_args()
    args.distributed = args.world_size > 1
    args.gpu = 0

    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()

    print(args)

    opts = vars(args)

    # Encode every option except the rank (so all ranks share one folder)
    # into a single path-safe string.
    name_log = ''.join('{}{}-'.format(key, val)
                       for key, val in sorted(opts.items())
                       if key != 'rank')
    name_log = name_log.replace('/', '-')
    name_log = name_log.replace('[', '-')
    name_log = name_log.replace(']', '-')

    # Split into path components of at most 100 characters. Slicing keeps the
    # final partial chunk, which zip()-based chunking silently dropped (so
    # runs differing only in the last sorted options shared a folder).
    chunk_len = 100
    name_log_list = [name_log[i:i + chunk_len]
                     for i in range(0, len(name_log), chunk_len)]

    print(name_log_list, '\n')

    folder_save = os.path.join(args.save_folder, *name_log_list)
    # makedirs creates missing parents and tolerates other ranks creating the
    # same directories concurrently.
    os.makedirs(folder_save, exist_ok=True)

    print('This will be saved in: ' + folder_save, '\n')

    args.bottleneck_width = json.loads(args.bottleneck_width)
    args.bottleneck_depth = json.loads(args.bottleneck_depth)

    if args.distributed:
        # args.gpu was set to args.rank % device_count above.
        torch.cuda.set_device(args.gpu)

    # Only one process writes checkpoints so ranks don't clobber the same file.
    is_main_process = not args.distributed or args.rank == 0

    global best_prec1
    global scat
    scat = Scattering(M=224, N=224, J=args.J, pre_pad=False).cuda()

    def save_checkpoint(state,
                        is_best,
                        filename=os.path.join(folder_save,
                                              'checkpoint.pth.tar')):
        """Save ``state`` and copy it to ``model_best.pth.tar`` if ``is_best``."""
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(folder_save, 'model_best.pth.tar'))

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    model = models.__dict__[args.arch](224,
                                       args.J,
                                       width=args.bottleneck_width,
                                       depth=args.bottleneck_depth,
                                       conv1x1=args.bottleneck_conv1x1)
    model.cuda()

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    print('Number of parameters: %d' % params)

    if is_main_process:
        save_checkpoint(
            {
                'epoch': -1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': 0,
            }, False)

    if args.distributed:
        model = DDP(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_main_process:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)