def reset(self):
        '''Rebuild model, teacher, optimizer and LR scheduler for a new trial.

        A mutable can only be initialized once, so every trial starts from
        freshly built objects instead of reusing those of the previous trial.
        '''
        cfg = self.cfg
        device = self.device

        # Fresh network, moved onto the target device.
        self.model = build_model(cfg)
        self.model.to(device)
        self.logger.info(f"Building model {cfg.model.name} ...")

        # The teacher network only exists when knowledge distillation is on.
        kd_enabled = hasattr(cfg, 'kd') and cfg.kd.enable
        if not kd_enabled:
            self.kd_model = None
        else:
            teacher = load_kd_model(cfg).to(device)
            self.kd_model = teacher
            teacher.eval()
            self.logger.info(
                f"Building teacher model {cfg.kd.model.name} ...")

        # Optimizer bound to the parameters of the new model.
        optim_cfg = cfg.optim
        optimizer_kwargs = dict(
            model=self.model,
            optim_name=optim_cfg.name,
            lr=optim_cfg.base_lr,
            momentum=optim_cfg.momentum,
            weight_decay=optim_cfg.weight_decay,
        )
        self.optimizer = generate_optimizer(**optimizer_kwargs)
        self.logger.info(f"Building optimizer {optim_cfg.name} ...")

        # LR scheduler wrapping the new optimizer.
        scheduler_name = optim_cfg.scheduler.name
        self.scheduler_params = parse_cfg_for_scheduler(cfg, scheduler_name)
        self.lr_scheduler = generate_scheduler(
            self.optimizer, scheduler_name, **self.scheduler_params)
        self.logger.info(f"Building optim.scheduler {scheduler_name} ...")
    def set_up(self):
        '''Build every component the trainer needs from ``self.cfg``.

        Creates, in order: model, mutator, datasets, loss function,
        optimizer and LR scheduler, then copies a few bookkeeping values
        (epoch count, log frequency, start epoch) onto the instance.
        '''
        cfg = self.cfg
        log = self.logger.info

        # model
        self.model = build_model(cfg)
        log(f"Building model {cfg.model.name} ...")

        # mutator -- report the candidates of the first LayerChoice found
        self.mutator = build_mutator(self.model, cfg)
        first_layer_choice = next(
            (mutable for mutable in self.mutator.mutables
             if isinstance(mutable, nni.nas.pytorch.mutables.LayerChoice)),
            None)
        if first_layer_choice is not None:
            log('Cell choices: {}'.format(first_layer_choice.choices))
        log(f"Building mutator {cfg.mutator.name} ...")

        # dataset
        dataset_cfg = cfg.dataset
        self.batch_size = dataset_cfg.batch_size
        self.workers = dataset_cfg.workers
        self.dataset_train, self.dataset_valid = build_dataset(cfg)
        log(f"Building dataset {dataset_cfg.name} ...")

        # loss
        self.loss = build_loss_fn(cfg)
        log(f"Building loss function {cfg.loss.name} ...")

        # optimizer
        optim_cfg = cfg.optim
        optimizer_kwargs = dict(
            model=self.model,
            optim_name=optim_cfg.name,
            lr=optim_cfg.base_lr,
            momentum=optim_cfg.momentum,
            weight_decay=optim_cfg.weight_decay,
        )
        self.optimizer = generate_optimizer(**optimizer_kwargs)
        log(f"Building optimizer {optim_cfg.name} ...")

        # scheduler
        scheduler_name = optim_cfg.scheduler.name
        self.scheduler_params = parse_cfg_for_scheduler(cfg, scheduler_name)
        self.lr_scheduler = generate_scheduler(
            self.optimizer, scheduler_name, **self.scheduler_params)
        log(f"Building optimizer scheduler {scheduler_name} ...")

        # miscellaneous
        self.num_epochs = cfg.trainer.num_epochs
        self.log_frequency = cfg.logger.log_frequency
        self.start_epoch = 0
Exemplo n.º 3
0
def main_worker(gpu, ngpus_per_node, args):
    """Build the model, optionally load weights, and run ``validate_shift``.

    Args:
        gpu: GPU index assigned to this process, or ``None``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split batch size / worker count per process.
        args: Parsed command-line namespace. Mutated in place (``gpu``,
            and in distributed mode ``rank``, ``batch_size``, ``workers``).

    Raises:
        ValueError: If ``args.size`` is neither 224 nor 128.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    model = build_model(args.arch, args)

    if args.weights is not None:
        print("=> using saved weights [%s]"%args.weights)
        weights = torch.load(args.weights)
        # Keys are loaded as-is (no 'module.' prefix stripping), so the
        # checkpoint must use the key names of the unwrapped model.
        model.load_state_dict(weights['state_dict'])

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True

    # Data loading code
    valdir = os.path.join(args.data, 'val')
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # crop size -> resize size. Fail loudly on anything else instead of
    # hitting an unbound local further down.
    resize_for_crop = {224: 256, 128: 174}
    if args.size not in resize_for_crop:
        raise ValueError(
            "unsupported args.size {!r}; expected 224 or 128".format(args.size))
    crop_size = args.size
    l_size = resize_for_crop[args.size]

    # Only validation runs here, so no training sampler is built (the
    # previous code referenced an undefined training dataset in
    # distributed mode).
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(l_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    validate_shift(val_loader, model, args)
Exemplo n.º 4
0
# Build the supervised (inputs, targets) pairs for the test split.
X_test, Y_test = preprocessing.create_dataset(test, look_back)

# reshape input to be [samples, time steps, features]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))
# Keep the array shape now: X_train is rebound to a tf.data.Dataset below.
train_shape = X_train.shape

# assert False

# From here on X_train is a dataset of (inputs, targets) pairs, not an array.
X_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))

# Batches of 64 with the incomplete last batch dropped, matching the
# batch_size=64 passed to build_model below.
X_train = X_train.shuffle(10000).batch(64, drop_remainder=True)

# Training model, built with train=True and a fixed batch size.
model = networks.build_model(train_shape,
                             neurons=64,
                             layers=6,
                             dropout_rate=0.0,
                             train=True,
                             batch_size=64)

# NOTE(review): ResetModelCallback presumably resets recurrent state between
# epochs -- confirm against networks.ResetModelCallback.
model.fit(X_train,
          epochs=20,
          verbose=1,
          shuffle=False,
          callbacks=[networks.ResetModelCallback()])

# Second model built with build_model's defaults for train/batch_size; it
# receives the trained weights so it can be used for prediction.
pred_model = networks.build_model(train_shape,
                                  neurons=64,
                                  layers=6,
                                  dropout_rate=0.0)
pred_model.set_weights(model.get_weights())
Exemplo n.º 5
0
    # Parse CLI arguments; a directory arc_path is normalised to end in '/'.
    args = parser.parse_args()
    config_file = args.config_file
    if os.path.isdir(args.arc_path) and args.arc_path[-1] != '/':
        args.arc_path += '/'
    arc_path = args.arc_path

    # NOTE(review): this only checks that both values are non-empty, not that
    # the paths exist, and it is skipped entirely under `python -O`.
    assert config_file and arc_path, f"please check whether {config_file} and {arc_path} exists"

    # configuration
    cfg = setup_cfg(args)
    # Dump the config before `args` is merged in, so retrain.yaml does not
    # contain the args entry.
    with open(os.path.join(cfg.logger.path, 'retrain.yaml'), 'w') as f:
        f.write(str(cfg))
    cfg.update({'args': args})
    logger = MyLogger(__name__, cfg).getlogger()
    logger.info('args:{}'.format(args))

    if args.cam_only:
        # CAM-only mode: fix the architecture from arc_path and run CAM3D.
        model = build_model(cfg)
        apply_fixed_architecture(model, args.arc_path)
        cam = CAM3D(cfg, model)
        cam.run()
    else:
        evaluator = build_evaluator(cfg)
        if os.path.isdir(arc_path):
            # Directory: let the evaluator compare and run the 'arc' entry of
            # the result it returns.
            best_arch_info = evaluator.compare()
            evaluator.run(best_arch_info['arc'])
        elif os.path.isfile(arc_path):
            # Single architecture file: run it directly.
            evaluator.run(arc_path, validate=True, test=args.test_only)
        else:
            logger.info(f'{arc_path} is invalid.')
Exemplo n.º 6
0
                    default=None,
                    type=str,
                    metavar='PATH',
                    help='path to pretrained model weights')
args = parser.parse_args()

# NOTE(review): hard-coded, machine-specific sample image; consider exposing
# it as a command-line argument.
img = Image.open(
    "/home/xueyan/antialias-cnn/data/ILSVRC2012/val/n04228054/ILSVRC2012_val_00000568.JPEG"
)

# Per-channel normalisation constants applied after to_tensor().
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Resize -> tensor -> normalise -> add a leading batch axis -> move to GPU.
img = F.resize(img, (224, 224), interpolation=2)
img = F.to_tensor(img)
img = (F.normalize(img, mean=mean, std=std)[None, :]).cuda()

model = build_model(args.arch, args).cuda()

if args.weights is not None:
    print("=> using saved weights [%s]" % args.weights)
    weights = torch.load(args.weights)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with 'module.'. Strip it only when present, so checkpoints saved from
    # an unwrapped model load too instead of having their keys truncated.
    prefix = 'module.'
    new_weights_sd = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weights['state_dict'].items()
    }
    weights['state_dict'] = new_weights_sd
    model.load_state_dict(weights['state_dict'])

# Single forward pass in eval mode without gradient tracking.
model.eval()
with torch.no_grad():
    output = model(img)
Exemplo n.º 7
0
def main_worker(gpu, ngpus_per_node, args):
    """Build the model and run one of the evaluation / weight-export modes.

    Depending on ``args`` this either saves "deparallelized" weights
    (``save_weights``) or runs ``validate``, ``validate_shift`` or
    ``validate_shift_correct``. If none of those options is set the function
    returns without evaluating anything.

    Args:
        gpu: GPU index assigned to this process, or ``None``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split batch size / worker count per process.
        args: Parsed command-line namespace. Mutated in place (``gpu``,
            ``batch_size``, and in distributed mode ``rank``, ``workers``).

    Raises:
        ValueError: If ``args.size`` is neither 224 nor 128.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create log file and timestamp
    # Append mode creates the file when missing, so no shell `touch` is
    # needed; `with` guarantees the handle is closed even if the write fails.
    log_pth = os.path.join(args.out_dir, 'log.txt')
    with open(log_pth, 'a') as log_file:
        log_file.write(str(datetime.now()) + '\n')

    # create model
    model = build_model(args.arch, args)

    if args.weights is not None:
        print("=> using saved weights [%s]" % args.weights)
        weights = torch.load(args.weights)

        # Checkpoints saved from a DataParallel-wrapped model prefix every
        # key with 'module.'. Strip it only when present so checkpoints
        # saved from an unwrapped model are not corrupted.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights['state_dict'].items()
        }

        evaluating = (args.evaluate or args.evaluate_shift
                      or args.evaluate_shift_correct
                      or args.evaluate_diagonal or args.evaluate_save)
        if args.num_classes != 1000 and not evaluating:
            # Fine-tuning with a different number of classes: keep the
            # model's own 'fc' parameters and take everything else from
            # the checkpoint.
            model_dict = model.state_dict()
            model_dict.update({
                key: value for key, value in state_dict.items()
                if 'fc' not in key
            })
            state_dict = model_dict

        model.load_state_dict(state_dict)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    cudnn.benchmark = True

    # Data loading code
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # crop size -> resize size. Fail loudly on anything else instead of
    # hitting an unbound local further down.
    resize_for_crop = {224: 256, 128: 174}
    if args.size not in resize_for_crop:
        raise ValueError(
            "unsupported args.size {!r}; expected 224 or 128".format(args.size))
    s_size = args.size
    l_size = resize_for_crop[args.size]

    valdir = os.path.join(args.data, 'val')

    # Shift / diagonal / save evaluations crop at the resized size; the
    # diagonal and save modes additionally force a batch size of 1.
    crop_size = l_size if (args.evaluate_shift or args.evaluate_diagonal
                           or args.evaluate_save) else s_size
    args.batch_size = 1 if (args.evaluate_diagonal
                            or args.evaluate_save) else args.batch_size

    # The same preprocessing is used by every supported dataset.
    val_transform = transforms.Compose([
        transforms.Resize(l_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'imagenet':
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
    elif args.dataset == 'vid':
        collator = VIDBatchCollator()
        val_loader = torch.utils.data.DataLoader(
            VidDataset(args.data, False, val_transform,
                       args.val_vid_imagenet, args.val_vid_soft),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            collate_fn=collator)
    elif args.dataset == 'vid_robust':
        collator = VIDRobustBatchCollator()
        val_loader = torch.utils.data.DataLoader(
            VidRobustDataset(args.data, False, val_transform,
                             args.robust_num),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            collate_fn=collator)
    else:
        assert False, "Not implemented error."

    if args.save_weights is not None:  # "deparallelize" saved weights
        print("=> saving 'deparallelized' weights [%s]" % args.save_weights)
        # TO-DO: automatically save this during training
        if args.gpu is not None:
            torch.save({'state_dict': model.state_dict()}, args.save_weights)
        else:
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                # Only model.features was wrapped; unwrap it before saving.
                model.features = model.features.module
                torch.save({'state_dict': model.state_dict()},
                           args.save_weights)
            else:
                torch.save({'state_dict': model.module.state_dict()},
                           args.save_weights)
        return

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    if args.evaluate_shift:
        validate_shift(val_loader, model, args)
        return

    if args.evaluate_shift_correct:
        validate_shift_correct(val_loader, model, args)
        return
Exemplo n.º 8
0
def main_worker(gpu, ngpus_per_node, args):
    """Train the model, or run one evaluation / weight-export mode and return.

    Builds the model, optionally loads pretrained weights and/or resumes from
    a checkpoint, builds the train and validation loaders, then either
    handles one of the one-shot modes (``save_weights``, ``evaluate``,
    ``evaluate_shift``, ``evaluate_shift_correct``, ``evaluate_diagonal``,
    ``evaluate_save``) or runs the training loop. After every epoch the
    accuracies are appended to ``<out_dir>/log.txt`` and a checkpoint is
    saved.

    Args:
        gpu: GPU index assigned to this process, or ``None``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split batch size / worker count per process.
        args: Parsed command-line namespace. Mutated in place (``gpu``,
            ``batch_size``, ``start_epoch`` on resume, and in distributed
            mode ``rank``, ``workers``).

    Raises:
        ValueError: If ``args.size`` is neither 224 nor 128.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create log file and timestamp
    # Append mode creates the file when missing, so no shell `touch` is
    # needed; `with` guarantees the handle is closed even if the write fails.
    log_pth = os.path.join(args.out_dir, 'log.txt')
    with open(log_pth, 'a') as log_file:
        log_file.write(str(datetime.now()) + '\n')

    # create model
    model = build_model(args.arch, args)

    if args.weights is not None:
        print("=> using saved weights [%s]"%args.weights)
        weights = torch.load(args.weights)

        # Checkpoints saved from a DataParallel-wrapped model prefix every
        # key with 'module.'. Strip it only when present so checkpoints
        # saved from an unwrapped model are not corrupted.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights['state_dict'].items()
        }

        evaluating = (args.evaluate or args.evaluate_shift
                      or args.evaluate_shift_correct
                      or args.evaluate_diagonal or args.evaluate_save)
        if args.num_classes != 1000 and not evaluating:
            # Fine-tuning with a different number of classes: keep the
            # model's own 'fc' parameters and take everything else from
            # the checkpoint.
            model_dict = model.state_dict()
            model_dict.update({
                key: value for key, value in state_dict.items()
                if 'fc' not in key
            })
            state_dict = model_dict
            print('warmning: please pay attention to weight loading when number of classes not equal to 1000.')

        model.load_state_dict(state_dict)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            if 'optimizer' in checkpoint:
                # Full training checkpoint: restore epoch, best accuracy
                # and optimizer state in addition to the weights.
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                if args.gpu is not None:
                    # best_acc1 may be from a checkpoint from a different GPU
                    best_acc1 = best_acc1.to(args.gpu)
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                # Weights-only checkpoint.
                print('  No optimizer saved')
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # crop size -> resize size. Fail loudly on anything else instead of
    # hitting an unbound local further down.
    resize_for_crop = {224: 256, 128: 174}
    if args.size not in resize_for_crop:
        raise ValueError(
            "unsupported args.size {!r}; expected 224 or 128".format(args.size))
    s_size = args.size
    l_size = resize_for_crop[args.size]

    # The training preprocessing depends only on no_data_aug, not on the
    # dataset, so it is built once.
    if args.no_data_aug:
        train_transform = transforms.Compose([
            transforms.Resize(l_size),
            transforms.CenterCrop(s_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(s_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    if args.dataset == 'imagenet':
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        train_dataset = datasets.ImageFolder(traindir, train_transform)
    elif args.dataset == 'vid':
        train_dataset = VidDataset(args.data, True, train_transform)
    else:
        assert False, "Not implemented error."

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    # Shift / diagonal / save evaluations crop at the resized size; the
    # diagonal and save modes additionally force a batch size of 1.
    crop_size = l_size if (args.evaluate_shift or args.evaluate_diagonal
                           or args.evaluate_save) else s_size
    args.batch_size = 1 if (args.evaluate_diagonal or args.evaluate_save) else args.batch_size

    val_transform = transforms.Compose([
        transforms.Resize(l_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'imagenet':
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    elif args.dataset == 'vid':
        val_loader = torch.utils.data.DataLoader(
            VidDataset(args.data, False, val_transform),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    else:
        assert False, "Not implemented error."

    if args.val_debug:  # debug mode - train on val set for faster epochs
        train_loader = val_loader

    if args.embed:
        embed()

    if args.save_weights is not None:  # "deparallelize" saved weights
        print("=> saving 'deparallelized' weights [%s]"%args.save_weights)
        # TO-DO: automatically save this during training
        if args.gpu is not None:
            torch.save({'state_dict': model.state_dict()}, args.save_weights)
        else:
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                # Only model.features was wrapped; unwrap it before saving.
                model.features = model.features.module
                torch.save({'state_dict': model.state_dict()}, args.save_weights)
            else:
                torch.save({'state_dict': model.module.state_dict()}, args.save_weights)
        return

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    if args.evaluate_shift:
        validate_shift(val_loader, model, args)
        return

    if args.evaluate_shift_correct:
        validate_shift_correct(val_loader, model, args)
        return

    if args.evaluate_diagonal:
        validate_diagonal(val_loader, model, args)
        return

    if args.evaluate_save:
        validate_save(val_loader, mean, std, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1, acc5 = validate(val_loader, model, criterion, args)

        # str() is kept explicit so the logged text of the accuracy values
        # is exactly what str() yields for them.
        with open(log_pth, 'a') as log_file:
            log_file.write('epoch: ' + str(epoch) + ', top-1 acc: ' + str(acc1)
                           + ', top-5 acc: ' + str(acc5) + ' \n')

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # With multiprocessing, only the first process on each node saves.
        if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, epoch, out_dir=args.out_dir)