Пример #1
0
def test_not_revalidate():
    """With ``revalidate`` off and no custom message, the skip notice is returned."""
    rule = {'revalidate': False}

    outcome = run.validate({}, rule)

    assert outcome == ['Skipping revalidation']
Пример #2
0
def test_validate_required_does_not_exist():
    """An empty ``info`` item fails a schema that lists ``label`` as required."""
    schema = {'required': ['label']}
    rule = {'revalidate': True, 'schema': schema, 'item': 'info'}
    target = {'info': {}}

    outcome = run.validate(target, rule)

    assert outcome == ['\'label\' is a required property']
Пример #3
0
def test_validate_required_does_not_exist():
    """A required string ``label`` that is absent from the container is reported."""
    schema = {'type': 'string', 'required': True}
    rule = {'revalidate': True, 'schema': schema, 'item': 'label'}

    outcome = run.validate({}, rule)

    assert outcome == '\'label\' is required'
Пример #4
0
def test_validate_required_exists_valid():
    """A required string ``label`` holding a string validates to ``True``."""
    schema = {'type': 'string', 'required': True}
    rule = {'revalidate': True, 'schema': schema, 'item': 'label'}
    target = {'label': 'STRING'}

    outcome = run.validate(target, rule)

    assert outcome is True
Пример #5
0
def test_validate_required_exists_invalid():
    """A required string ``label`` holding an integer yields a type error message."""
    schema = {'type': 'string', 'required': True}
    rule = {'revalidate': True, 'schema': schema, 'item': 'label'}
    target = {'label': 1}

    outcome = run.validate(target, rule)

    assert outcome == '1 is not of type \'string\''
Пример #6
0
def test_validate_item_does_not_exist(caplog):
    """A missing ``info`` item is both returned and logged as a warning.

    The same trailing sentence must end the first returned message and at
    least one message captured by ``caplog``.
    """
    schema = {'required': ['label']}
    rule = {'revalidate': True, 'schema': schema, 'item': 'info'}
    target = {'name': 'test.dcm'}

    with caplog.at_level(logging.WARNING):
        outcome = run.validate(target, rule)
        first_message = outcome[0]
        assert first_message.endswith(
            'Please confirm metadata are not missing.')
        logged_flags = (
            record.endswith('Please confirm metadata are not missing.')
            for record in caplog.messages
        )
        assert any(logged_flags)
Пример #7
0
def test_not_revalidate_message():
    """With ``revalidate`` off, a supplied ``error_message`` is returned as-is."""
    rule = {
        'revalidate': False,
        'error_message': 'Not enough images',
    }

    outcome = run.validate({}, rule)

    assert outcome == ['Not enough images']
Пример #8
0
def test_validate_does_not_exists():
    """A non-required ``label`` that is absent from the container validates to ``True``."""
    rule = {
        'revalidate': True,
        'schema': {'type': 'string'},
        'item': 'label',
    }

    outcome = run.validate({}, rule)

    assert outcome is True
Пример #9
0
def main_worker(gpu, ngpus_per_node, args):
    """Build the model, optimizer and data loaders, then run the training loop.

    Args:
        gpu: GPU index handled by this worker, or ``None``. Stored on
            ``args.gpu``.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and to split ``args.batch_size`` / ``args.workers`` between
            processes in distributed mode.
        args: Option namespace. It is mutated in place (``gpu``, ``rank``,
            ``batch_size``, ``workers``, ``lr``, ``bs``, ``start_epoch``).

    Side effects:
        Rebinds the module-level ``best_acc1`` and calls ``save_checkpoint``
        once per epoch from the process selected by the rank check.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ------------------------------------------------------------------
    # create model
    # ------------------------------------------------------------------
    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            print('=> using pre-trained {}'.format(args.arch))
            model = EfficientNet.from_pretrained(args.arch, advprop=args.advprop)
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
    else:
        if args.arch.startswith('efficientnet-b'):
            print("=> creating model {}".format(args.arch))
            model = EfficientNet.from_name(args.arch)
        elif args.arch.startswith('Dense'):
            print("=> creating model {}".format(args.arch))
            model = DenseNet40()
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    # create teacher model
    if args.kd:
        print('=> loading teacher model')
        if args.teacher_arch.startswith('efficientnet-b'):
            teacher = EfficientNet.from_pretrained(args.teacher_arch)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.teacher_arch.startswith('resnext101_32'):
            teacher = torch.hub.load('facebookresearch/WSL-Images', '{}_wsl'.format(args.teacher_arch))
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.overhaul:
            teacher = resnet.resnet152(pretrained=True)
        else:
            teacher = models.__dict__[args.teacher_arch](pretrained=True)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        if args.overhaul:
            print('=> using overhaul distillation')
            d_net = Distiller(teacher, model)

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Each process drives a single GPU, so the global batch size and
            # the worker count are divided between the processes of the node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            if args.kd:
                teacher = torch.nn.DataParallel(teacher).cuda()
                if args.overhaul:
                    d_net = torch.nn.DataParallel(d_net).cuda()

    if args.pretrained and args.arch.startswith('efficientnet-b'):
        # Only build a map_location when a GPU index is known; formatting
        # ``None`` would produce the invalid device string 'cuda:None'.
        if args.gpu is None:
            checkpoint = torch.load(args.pth_path)
        else:
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.pth_path, map_location=loc)
        model.load_state_dict(checkpoint['state_dict'])

    # ------------------------------------------------------------------
    # define loss function (criterion) and optimizer, scheduler
    # ------------------------------------------------------------------
    if args.kd and not args.overhaul:
        criterion = kd_criterion
    else:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # Stays None on the overhaul path, which defines no LR schedule; the
    # training loop checks for that before stepping.
    scheduler = None

    if args.overhaul:
        # d_net is wrapped in DataParallel only on the non-distributed,
        # gpu-less path, so unwrap it only when a wrapper is present.
        connectors = getattr(d_net, 'module', d_net).Connectors
        optimizer = torch.optim.SGD(list(model.parameters()) + list(connectors.parameters()), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)  # nesterov
    else:
        # The SGD / AdamW / CosineAnnealingLR objects that used to be built
        # here were overwritten before use; only the effective RMSprop +
        # step-decay configuration is kept.
        args.lr = 0.048
        # NOTE(review): ``args.bs`` is not read anywhere in this function; the
        # data loaders below use ``args.batch_size`` - confirm the intent.
        args.bs = 384
        optimizer = torch.optim.RMSprop(
            model.parameters(), lr=args.lr, alpha=0.9, eps=.001,
            momentum=0.9, weight_decay=args.weight_decay)

        from typing import Dict, Any

        class Scheduler:
            """ Parameter Scheduler Base Class
            A scheduler base class that can be used to schedule any optimizer parameter groups.
            Unlike the builtin PyTorch schedulers, this is intended to be consistently called
            * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
            * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
            The schedulers built on this should try to remain as stateless as possible (for simplicity).
            This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
            and -1 values for special behaviour. All epoch and update counts must be tracked in the training
            code and explicitly passed in to the schedulers on the corresponding step or step_update call.
            Based on ideas from:
             * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
             * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
            """

            def __init__(self,
                         optimizer: torch.optim.Optimizer,
                         param_group_field: str,
                         noise_range_t=None,
                         noise_type='normal',
                         noise_pct=0.67,
                         noise_std=1.0,
                         noise_seed=None,
                         initialize: bool = True) -> None:
                self.optimizer = optimizer
                self.param_group_field = param_group_field
                self._initial_param_group_field = f"initial_{param_group_field}"
                if initialize:
                    # Remember the starting value of the scheduled field in each group.
                    for i, group in enumerate(self.optimizer.param_groups):
                        if param_group_field not in group:
                            raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                        group.setdefault(self._initial_param_group_field, group[param_group_field])
                else:
                    for i, group in enumerate(self.optimizer.param_groups):
                        if self._initial_param_group_field not in group:
                            raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
                self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
                self.metric = None  # any point to having this for all?
                self.noise_range_t = noise_range_t
                self.noise_pct = noise_pct
                self.noise_type = noise_type
                self.noise_std = noise_std
                self.noise_seed = noise_seed if noise_seed is not None else 42
                self.update_groups(self.base_values)

            def state_dict(self) -> Dict[str, Any]:
                """Return every attribute except the optimizer reference."""
                return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

            def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
                """Restore attributes previously produced by ``state_dict``."""
                self.__dict__.update(state_dict)

            def get_epoch_values(self, epoch: int):
                """Per-epoch values; ``None`` means "leave the groups untouched"."""
                return None

            def get_update_values(self, num_updates: int):
                """Per-update values; ``None`` means "leave the groups untouched"."""
                return None

            def step(self, epoch: int, metric: float = None) -> None:
                """Apply the schedule for ``epoch`` (plus optional noise)."""
                self.metric = metric
                values = self.get_epoch_values(epoch)
                if values is not None:
                    values = self._add_noise(values, epoch)
                    self.update_groups(values)

            def step_update(self, num_updates: int, metric: float = None):
                """Apply the schedule for ``num_updates`` (plus optional noise)."""
                self.metric = metric
                values = self.get_update_values(num_updates)
                if values is not None:
                    values = self._add_noise(values, num_updates)
                    self.update_groups(values)

            def update_groups(self, values):
                """Write ``values`` into the scheduled field of each param group."""
                if not isinstance(values, (list, tuple)):
                    # A scalar is broadcast to every param group.
                    values = [values] * len(self.optimizer.param_groups)
                for param_group, value in zip(self.optimizer.param_groups, values):
                    param_group[self.param_group_field] = value

            def _add_noise(self, lrs, t):
                """Scale ``lrs`` by a seeded random factor when ``t`` is in the noise range."""
                if self.noise_range_t is not None:
                    if isinstance(self.noise_range_t, (list, tuple)):
                        apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
                    else:
                        apply_noise = t >= self.noise_range_t
                    if apply_noise:
                        # Seeding with ``noise_seed + t`` makes the noise reproducible per step.
                        g = torch.Generator()
                        g.manual_seed(self.noise_seed + t)
                        if self.noise_type == 'normal':
                            while True:
                                # resample if noise out of percent limit, brute force but shouldn't spin much
                                noise = torch.randn(1, generator=g).item()
                                if abs(noise) < self.noise_pct:
                                    break
                        else:
                            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
                        lrs = [v + v * noise for v in lrs]
                return lrs

        class StepLRScheduler(Scheduler):
            """Step decay of ``lr`` with an optional linear warmup.

            For ``t < warmup_t`` the value grows linearly from ``warmup_lr_init``;
            afterwards it is ``base * decay_rate ** (t // decay_t)``.
            """

            def __init__(self,
                         optimizer: torch.optim.Optimizer,
                         decay_t: float,
                         decay_rate: float = 1.,
                         warmup_t=0,
                         warmup_lr_init=0,
                         t_in_epochs=True,
                         noise_range_t=None,
                         noise_pct=0.67,
                         noise_std=1.0,
                         noise_seed=42,
                         initialize=True,
                         ) -> None:
                super().__init__(
                    optimizer, param_group_field="lr",
                    noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
                    initialize=initialize)

                self.decay_t = decay_t
                self.decay_rate = decay_rate
                self.warmup_t = warmup_t
                self.warmup_lr_init = warmup_lr_init
                self.t_in_epochs = t_in_epochs
                if self.warmup_t:
                    # Per-step increment that reaches each base value after warmup_t steps.
                    self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
                    super().update_groups(self.warmup_lr_init)
                else:
                    self.warmup_steps = [1 for _ in self.base_values]

            def _get_lr(self, t):
                if t < self.warmup_t:
                    lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
                else:
                    lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
                return lrs

            def get_epoch_values(self, epoch: int):
                if self.t_in_epochs:
                    return self._get_lr(epoch)
                else:
                    return None

            def get_update_values(self, num_updates: int):
                if not self.t_in_epochs:
                    return self._get_lr(num_updates)
                else:
                    return None

        scheduler = StepLRScheduler(
            optimizer,
            decay_t=2.4,
            decay_rate=0.97,
            warmup_lr_init=1e-6,
            warmup_t=3,
            noise_range_t=None,
            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
            noise_std=getattr(args, 'lr_noise_std', 1.),
            noise_seed=getattr(args, 'seed', 42),
        )

    # Initialise before the resume block so that a value restored from the
    # checkpoint is not reset to 0 again right before the training loop.
    best_acc1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Data loading code
    # ------------------------------------------------------------------
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.advprop:
        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_dataset = ImageFolder_iid(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        ImageFolder_iid(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        # NOTE(review): there is no early return here, so training still runs
        # after this evaluation pass - confirm that this is intended.
        validate(val_loader, model, criterion, args)

    # ------------------------------------------------------------------
    # Start training
    # ------------------------------------------------------------------
    teacher_name = ''
    student_name = ''
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        if args.kd:
            if args.overhaul:
                train_with_overhaul(train_loader, d_net, optimizer, criterion, epoch, args)
                acc1 = validate_overhaul(val_loader, model, criterion, epoch, args)
            else:
                train_kd(train_loader, teacher, model, criterion, optimizer, epoch, args)
                acc1 = validate_kd(val_loader, teacher, model, criterion, args)

                # The teacher carries a ``.module`` attribute only when it
                # was wrapped in DataParallel above.
                teacher_name = getattr(teacher, 'module', teacher).__class__.__name__

        else:
            # Same for the student: not every device path wraps the model.
            student_name = getattr(model, 'module', model).__class__.__name__
            train(train_loader, model, criterion, optimizer, epoch, args)
            acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        if acc1 < 65:
            print(colored('not saving... accuracy smaller than 65','green'))
            is_best = False
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, teacher_name=teacher_name, student_name=student_name, save_path=args.save_path, acc=acc1)

        if scheduler is not None:
            scheduler.step(epoch)
Пример #10
0
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.lr,
                                  betas=(0.9, 0.999),
                                  eps=1e-08,
                                  weight_decay=args.weight_decay,
                                  amsgrad=False)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=args.epochs *
                                  int(1281167 / args.batch_size),
                                  eta_min=0,
                                  last_epoch=-1)

    best_acc1 = 0
    for epoch in range(0, args.epochs):
        train(train_loader, model, criterion, optimizer, epoch, args)
        acc1 = validate(val_loader, model, criterion, args)

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        model_name = model.module.__class__.__name__
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            teacher_name='',
            student_name=model_name,
Пример #11
0
    # Final evaluation of a saved parser model on the test set.
    # NOTE(review): the enclosing function header is outside this view;
    # `opt`, `parser` and `checkpoint` are defined before this point.

    # Use CUDA when it is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(80 * "=")
    print("INITIALIZING")
    print(80 * "=")
    start = time.time()  # timestamp for the "took ... seconds" report below
    debug = False
    test_data,pad_action = load_and_preprocess_data_test(opt,parser,debug)
    
    # Group the test data into batches of `opt.batchsize`.
    test_batched = batch_dev_test(test_data, opt.batchsize, parser.NULL, parser.P_NULL,
                                  parser ,no_sort=False)

    model = ParserModel(parser.embedding_shape, device, parser, pad_action, opt)
    
    # strict=False: presumably tolerates key mismatches between the checkpoint
    # and the model (torch.nn.Module semantics) - TODO confirm this is intended.
    model.load_state_dict(checkpoint['model'],strict=False)
    
    model = model.to(device)
    print("took {:.2f} seconds\n".format(time.time() - start))
    
    print(80 * "=")
    print("TESTING")
    print(80 * "=")
    print("Final evaluation on test set", )
    model.eval()

    # Both scores are multiplied by 100 before being printed.
    UAS, LAS = validate(model, parser, test_batched, test_data, device, opt.batchsize,pad_action['P'],opt)
    print("- test UAS: {:.2f}".format(UAS * 100.0))
    print("- test LAS: {:.2f}".format(LAS * 100.0))
    print("Done!")

Пример #12
0
def main_worker(gpu, ngpus_per_node, args):
    """Build the model, optimizer and data loaders, then run the training loop.

    Args:
        gpu: GPU index handled by this worker, or ``None``. Stored on
            ``args.gpu``.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and to split ``args.batch_size`` / ``args.workers`` between
            processes in distributed mode.
        args: Option namespace. It is mutated in place (``gpu``, ``rank``,
            ``batch_size``, ``workers``, ``lr``, ``start_epoch``).

    Side effects:
        Rebinds the module-level ``best_acc1`` and calls ``save_checkpoint``
        once per epoch from the process selected by the rank check.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # ------------------------------------------------------------------
    # create model
    # ------------------------------------------------------------------
    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            print('=> using pre-trained {}'.format(args.arch))
            model = EfficientNet.from_pretrained(args.arch,
                                                 advprop=args.advprop)
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
    else:
        if args.arch.startswith('efficientnet-b'):
            print("=> creating model {}".format(args.arch))
            model = EfficientNet.from_name(args.arch)
        elif args.arch.startswith('Dense'):
            print("=> creating model {}".format(args.arch))
            model = DenseNet40()
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    # create teacher model
    if args.kd:
        print('=> loading teacher model')
        if args.teacher_arch.startswith('efficientnet-b'):
            teacher = EfficientNet.from_pretrained(args.teacher_arch)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.teacher_arch.startswith('resnext101_32'):
            teacher = torch.hub.load('facebookresearch/WSL-Images',
                                     '{}_wsl'.format(args.teacher_arch))
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.overhaul:
            teacher = resnet.resnet152(pretrained=True)
        else:
            teacher = models.__dict__[args.teacher_arch](pretrained=True)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        if args.overhaul:
            print('=> using overhaul distillation')
            d_net = Distiller(teacher, model)

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Each process drives a single GPU, so the global batch size and
            # the worker count are divided between the processes of the node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            if args.kd:
                teacher = torch.nn.DataParallel(teacher).cuda()
                if args.overhaul:
                    d_net = torch.nn.DataParallel(d_net).cuda()

    if args.pretrained and args.arch.startswith('efficientnet-b'):
        # Only build a map_location when a GPU index is known; formatting
        # ``None`` would produce the invalid device string 'cuda:None'.
        if args.gpu is None:
            checkpoint = torch.load(args.pth_path)
        else:
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.pth_path, map_location=loc)
        model.load_state_dict(checkpoint['state_dict'])

    # ------------------------------------------------------------------
    # define loss function (criterion) and optimizer, scheduler
    # ------------------------------------------------------------------
    if args.kd and not args.overhaul:
        criterion = kd_criterion
    else:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # Stays None on the overhaul path, which defines no LR schedule; the
    # training loop checks for that before stepping.
    scheduler = None

    if args.overhaul:
        # d_net is wrapped in DataParallel only on the non-distributed,
        # gpu-less path, so unwrap it only when a wrapper is present.
        connectors = getattr(d_net, 'module', d_net).Connectors
        optimizer = torch.optim.SGD(list(model.parameters()) +
                                    list(connectors.parameters()),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)  # nesterov
    else:
        # The SGD / AdamW / CosineAnnealingLR objects that used to be built
        # here were overwritten before use; only the effective RMSpropTF +
        # step-decay configuration is kept.
        args.lr = 0.048
        # NOTE(review): this replaces any per-process batch size computed in
        # the distributed branch above - confirm the intent.
        args.batch_size = 384
        parameters = add_weight_decay(model, 1e-5)
        optimizer = RMSpropTF(parameters,
                              lr=0.048,
                              alpha=0.9,
                              eps=0.001,
                              momentum=args.momentum,
                              weight_decay=1e-5)

        scheduler = StepLRScheduler(
            optimizer,
            decay_t=2.4,
            decay_rate=0.97,
            warmup_lr_init=1e-6,
            warmup_t=3,
            noise_range_t=None,
            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
            noise_std=getattr(args, 'lr_noise_std', 1.),
            noise_seed=getattr(args, 'seed', 42),
        )

    # Initialise before the resume block so that a value restored from the
    # checkpoint is not reset to 0 again right before the training loop.
    best_acc1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Data loading code
    # ------------------------------------------------------------------
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.advprop:
        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_dataset = ImageFolder_iid(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_dataset = ImageFolder_iid(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        # NOTE(review): there is no early return here, so training still runs
        # after this evaluation pass, and unlike the call in the loop no EMA
        # object is passed - confirm that both are intended.
        validate(val_loader, model, criterion, args)

    # ------------------------------------------------------------------
    # Start training
    # ------------------------------------------------------------------
    teacher_name = ''
    student_name = ''

    ema = EMA(model, 0.9999)
    ema.register()
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        if args.kd:
            if args.overhaul:
                train_with_overhaul(train_loader, d_net, optimizer, criterion,
                                    epoch, args)
                acc1 = validate_overhaul(val_loader, model, criterion, epoch,
                                         args)
            else:
                train_kd(train_loader, teacher, model, criterion, optimizer,
                         epoch, args)
                acc1 = validate_kd(val_loader, teacher, model, criterion, args)

                # The teacher carries a ``.module`` attribute only when it
                # was wrapped in DataParallel above.
                teacher_name = getattr(teacher, 'module',
                                       teacher).__class__.__name__

        else:
            # Same for the student: not every device path wraps the model.
            student_name = getattr(model, 'module', model).__class__.__name__
            train(train_loader, model, criterion, optimizer, epoch, args, ema)
            acc1 = validate(val_loader, model, criterion, args, ema)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        if acc1 < 65:
            print(colored('not saving... accuracy smaller than 65', 'green'))
            is_best = False
        best_acc1 = max(acc1, best_acc1)

        if (not args.multiprocessing_distributed
                or args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                teacher_name=teacher_name,
                student_name=student_name,
                save_path=args.save_path,
                acc=acc1)

        if scheduler is not None:
            scheduler.step(epoch)