Example #1
def do_training(args):
    trainloader, testloader = build_dataset(
        args.dataset,
        dataroot=args.dataroot,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        num_workers=2)
    model = build_model(args.arch, num_classes=num_classes(args.dataset))
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    # Calculate total number of model parameters
    num_params = sum(p.numel() for p in model.parameters())
    track.metric(iteration=0, num_params=num_params)

    if args.optimizer == 'sgd':
        optimizer = SGD(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    else:
        optimizer = EKFAC(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          eps=args.eps,
                          update_freq=args.update_freq)

    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(args.epochs):
        track.debug("Starting epoch %d" % epoch)
        args.lr = adjust_learning_rate(epoch, optimizer, args.lr,
                                       args.schedule, args.gamma)
        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, args.cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   args.cuda)
        track.debug('Finished epoch %d... | train loss %.3f | train acc %.3f '
                    '| test loss %.3f | test acc %.3f' %
                    (epoch, train_loss, train_acc, test_loss, test_acc))
        # Save model
        model_fname = os.path.join(track.trial_dir(),
                                   "model{}.ckpt".format(epoch))
        torch.save(model, model_fname)
        if test_acc > best_acc:
            best_acc = test_acc
            best_fname = os.path.join(track.trial_dir(), "best.ckpt")
            track.debug("New best score! Saving model")
            torch.save(model, best_fname)
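The snippet checkpoints the entire module object with torch.save(model, path). A minimal restore sketch (an assumption, not part of the original example; the path is illustrative):

import torch

# torch.save(model, ...) pickles the whole module, so the model class (and the
# DataParallel wrapper, if one was used) must be importable at load time.
model = torch.load("best.ckpt", map_location="cpu")
model.eval()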
Example #2
def do_training(args):
    trainloader, testloader = build_dataset(
        args.dataset,
        dataroot=args.dataroot,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        num_workers=2)
    model = build_model(args.arch, num_classes=num_classes(args.dataset))
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    # Calculate total number of model parameters
    num_params = sum(p.numel() for p in model.parameters())
    track.metric(iteration=0, num_params=num_params)

    num_chunks = max(1, args.batch_size // args.max_samples_per_gpu)

    optimizer = LARS(params=model.parameters(),
                     lr=args.lr,
                     momentum=args.momentum,
                     weight_decay=args.weight_decay,
                     eta=args.eta,
                     max_epoch=args.epochs)

    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(args.epochs):
        track.debug("Starting epoch %d" % epoch)
        train_loss, train_acc = train(trainloader,
                                      model,
                                      criterion,
                                      optimizer,
                                      epoch,
                                      args.cuda,
                                      num_chunks=num_chunks)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   args.cuda)
        track.debug('Finished epoch %d... | train loss %.3f | train acc %.3f '
                    '| test loss %.3f | test acc %.3f' %
                    (epoch, train_loss, train_acc, test_loss, test_acc))
        # Save model
        model_fname = os.path.join(track.trial_dir(),
                                   "model{}.ckpt".format(epoch))
        torch.save(model, model_fname)
        if test_acc > best_acc:
            best_acc = test_acc
            best_fname = os.path.join(track.trial_dir(), "best.ckpt")
            track.debug("New best score! Saving model")
            torch.save(model, best_fname)
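The num_chunks argument suggests that train() splits each batch into sub-batches small enough for one GPU and accumulates gradients across them. The body of train() is not shown, so the following is only a sketch of that usual pattern:

def accumulate_step(model, criterion, optimizer, inputs, targets, num_chunks):
    # Assumed behavior of train(..., num_chunks=...): split the batch,
    # accumulate gradients chunk by chunk, then take a single optimizer step.
    optimizer.zero_grad()
    for x, y in zip(inputs.chunk(num_chunks), targets.chunk(num_chunks)):
        loss = criterion(model(x), y) / num_chunks  # keep gradient scale comparable
        loss.backward()
    optimizer.step()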
Example #3
def _main(_):
    with track.trial(os.getenv('TRACK_DIRECTORY'),
                     param_map=track.absl_flags(),
                     trial_prefix=flags.FLAGS.trial_prefix):
        seed_all(flags.FLAGS.seed)
        track.debug('found gpus {}', gpus())

        dataset_file = os.path.join(
            flags.FLAGS.dataroot, 'wikisql',
            'processed-toy{}.pth'.format(1 if flags.FLAGS.toy else 0))
        track.debug('loading data from {}', dataset_file)
        train, val, _ = torch.load(dataset_file)

        track.debug('building model')
        model = wikisql_specific.WikiSQLSpecificModel(train.fields)
        track.debug('built model:\n{}', model)
        num_parameters = int(
            sum(p.numel() for p in model.parameters() if p.requires_grad))
        track.debug('number of parameters in model {}', num_parameters)

        device = get_device()
        torch.save(model.to(torch.device('cpu')),
                   os.path.join(track.trial_dir(), 'untrained_model.pth'))
        model = model.to(device)
        training_state = _TrainingState()
        if flags.FLAGS.restore_checkpoint:
            _copy_best_checkpoint(flags.FLAGS.restore_checkpoint)
            _load_checkpoint(flags.FLAGS.restore_checkpoint, model,
                             training_state)
        params_to_optimize = [p for p in model.parameters() if p.requires_grad]
        if flags.FLAGS.optimizer == 'sgd':
            # lr required here but will be set in _do_training
            optimizer = optim.SGD(params_to_optimize,
                                  lr=1,
                                  weight_decay=flags.FLAGS.weight_decay)
        elif flags.FLAGS.optimizer == 'momentum':
            # lr required here but will be set in _do_training
            optimizer = optim.SGD(params_to_optimize,
                                  lr=1,
                                  momentum=0.9,
                                  weight_decay=flags.FLAGS.weight_decay)
        elif flags.FLAGS.optimizer == 'adam':
            optimizer = optim.Adam(params_to_optimize,
                                   weight_decay=flags.FLAGS.weight_decay)
        else:
            raise ValueError('unrecognized optimizer {}'.format(
                flags.FLAGS.optimizer))

        num_workers = flags.FLAGS.workers
        track.debug('initializing {} workers', num_workers)
        with closing(SharedGPU(optimizer, model, num_workers)) as shared:
            _do_training(train, val, shared, training_state)
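seed_all() is not defined in the snippet; a plausible implementation seeding the usual RNGs (an assumption, the project's actual helper may differ):

import random

import numpy as np
import torch


def seed_all(seed):
    # Seed every RNG a PyTorch training run typically touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)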
Example #4
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)" % (
                    test_loss / (batch_idx + 1),
                    100.0 * correct / total,
                    correct,
                    total,
                ),
            )

    # Save checkpoint.
    acc = 100.0 * correct / total
    if acc > best_acc:
        print("Saving..")
        state = {"net": net.state_dict(), "acc": acc, "epoch": epoch}
        # The checkpoint lives in the trial directory managed by track,
        # so no local "checkpoint/" folder is needed.
        ckpt_path = os.path.join(track.trial_dir(), "ckpt.pth")
        torch.save(state, ckpt_path)
        best_acc = acc
    test_loss = test_loss / len(testloader)
    return test_loss, acc, best_acc
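The top-1 accuracy bookkeeping above is easy to verify in isolation; a self-contained illustration with made-up logits:

import torch

outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # [N, C] logits
targets = torch.tensor([1, 0, 0])
_, predicted = outputs.max(1)                  # argmax over the class dimension
correct = predicted.eq(targets).sum().item()   # 2 of 3 correct here
acc = 100.0 * correct / targets.size(0)        # 66.67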
Example #5
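# (Fragment: these lines run inside an enclosing `for epoch ...:` loop wrapped
#  in a `try:` block; those outer lines are not part of the snippet.)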
            epoch_start_time = time.time()
            train_loss = train()
            val_loss = evaluate(val_data)
            print('-' * 89)
            track.debug(
                '| end of epoch {:3d} | time: {:5.2f}s | train loss {:5.2f} | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch,
                                           (time.time() - epoch_start_time),
                                           train_loss, val_loss,
                                           math.exp(val_loss)))
            print('-' * 89)
            track.metric(iteration=epoch,
                         train_loss=train_loss,
                         test_loss=val_loss)
            # Log model
            model_fname = os.path.join(track.trial_dir(),
                                       "model{}.ckpt".format(epoch))
            torch.save(model, model_fname)
            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                best_fname = os.path.join(track.trial_dir(), "best.ckpt")
                with open(best_fname, 'wb') as f:
                    torch.save(model, f)
                best_val_loss = val_loss
            else:
                # Anneal the learning rate if no improvement has been seen in the validation dataset.
                lr /= 4.0
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
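Note that lr /= 4.0 only takes effect if the loop applies lr itself (manual parameter updates) or writes it back into an optimizer. Under the latter assumption, the write-back would look like:

for group in optimizer.param_groups:
    group['lr'] = lr  # push the annealed rate into the optimizer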
Example #6
def do_training(args):

    hyperparameters = {
        'lr': args.lr,
        'epochs': args.epochs,
        'resume_from': 0,
        'coco_version': args.coco_version,  # can be either '2014' or '2017'
        'batch_size': args.batch_size,
        'weight_decay': args.weight_decay,
        'momentum': args.momentum,
        'optimizer': args.optimizer,
        'alpha': args.alpha,
        'gamma': args.gamma,
        'lcoord': args.lcoord,
        'lno_obj': args.lno_obj,
        'iou_type': tuple(int(a) for a in tuple(args.iou_type)),
        'iou_ignore_thresh': args.iou_ignore_thresh,
        'tfidf': args.tfidf,
        'idf_weights': True,
        'tfidf_col_names': ['img_freq', 'none', 'none', 'none', 'no_softmax'],
        'wasserstein': args.wasserstein,
        'inf_confidence': args.inf_confidence,
        'inf_iou_threshold': args.inf_iou_threshold,
        'augment': args.augment,
        'workers': 1,
        'pretrained': args.is_pretrained,
        'path': args.trial_id,
        'reduction': args.reduction
    }

    mode = {
        'bayes_opt': False,
        'multi_scale': args.multi_scale,
        'show_hp': args.show_hp,
        'show_output': args.show_output,
        'multi_gpu': False,
        'train_subset': args.train_subset,
        'test_subset': args.test_subset,
        'show_temp_summary': args.show_temp_summary,
        'save_summary': False
    }

    this_proj = track.Project("./logs/" + args.experimentname)
    if args.resume == 'last':
        most_recent = this_proj.ids["start_time"].nlargest(2).idxmin()
        most_recent_id = this_proj.ids["trial_id"].iloc[[most_recent]]
        PATH = os.path.join("./logs/" + args.experimentname,
                            most_recent_id.item())
        hyperparameters['path'] = os.path.join(PATH, 'last.tar')
        args.resume = most_recent_id.item()
    elif args.resume == 'best':
        ids = this_proj.ids["trial_id"]
        res = this_proj.results(ids)
        best_map = res["coco_stats:map_all"].idxmax()
        best_map_id = res["trial_id"].iloc[[best_map]]
        PATH = os.path.join("./logs/" + args.experimentname,
                            best_map_id.item())
        hyperparameters['path'] = os.path.join(PATH, 'best.tar')
        args.resume = best_map_id.item()
    else:
        PATH = os.path.join("./logs/" + args.experimentname, args.resume)
        hyperparameters['path'] = os.path.join(PATH, 'last.tar')

    coco_version = hyperparameters['coco_version']
    mAP_best = 0
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model, optimizer, hyperparameters, PATH = init_model.init_model(
        hyperparameters, mode)

    model.hp = hyperparameters
    model.mode = mode

    if type(model) is nn.DataParallel:
        inp_dim = model.module.inp_dim
    else:
        inp_dim = model.inp_dim

    if hyperparameters['augment'] > 0:
        train_dataset = Coco(partition='train',
                             coco_version=coco_version,
                             subset=mode['train_subset'],
                             transform=transforms.Compose([
                                 Augment(hyperparameters['augment']),
                                 ResizeToTensor(inp_dim)
                             ]))
    else:
        train_dataset = Coco(partition='train',
                             coco_version=coco_version,
                             subset=mode['train_subset'],
                             transform=transforms.Compose(
                                 [ResizeToTensor(inp_dim)]))

    batch_size = hyperparameters['batch_size']

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  collate_fn=helper.collate_fn,
                                  num_workers=hyperparameters['workers'],
                                  pin_memory=True)

    test_dataset = Coco(partition='val',
                        coco_version=coco_version,
                        subset=mode['test_subset'],
                        transform=transforms.Compose(
                            [ResizeToTensor(inp_dim)]))

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.eval_batch_size,
                                 shuffle=False,
                                 collate_fn=helper.collate_fn,
                                 num_workers=1,
                                 pin_memory=True)

    # Calculate total number of model parameters
    num_params = sum(p.numel() for p in model.parameters())
    track.metric(iteration=0, num_params=num_params)

    for epoch in range(args.epochs):
        track.debug("Starting epoch %d" % epoch)
        # args.lr = adjust_learning_rate(epoch, optimizer, args.lr,
        #                                args.schedule, args.gamma)

        outcome = train(train_dataloader, model, optimizer, epoch)

        mAP = test(test_dataloader, model, epoch, device)

        track.debug(
            'Finished epoch %d... | train loss %.3f | avg_iou %.3f | avg_conf %.3f | avg_no_conf %.3f'
            '| avg_pos %.3f | avg_neg %.5f | mAP %.5f' %
            (epoch, outcome['avg_loss'], outcome['avg_iou'],
             outcome['avg_conf'], outcome['avg_no_conf'], outcome['avg_pos'],
             outcome['avg_neg'], mAP))

        model_fname = os.path.join(track.trial_dir(), "last.tar")
        checkpoint = {
            'model_state_dict': model.module.state_dict()
            if type(model) is nn.DataParallel else model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'avg_loss': outcome['avg_loss'],
            'avg_iou': outcome['avg_iou'],
            'avg_pos': outcome['avg_pos'],
            'avg_neg': outcome['avg_neg'],
            'avg_conf': outcome['avg_conf'],
            'avg_no_conf': outcome['avg_no_conf'],
            'mAP': mAP,
            'hyperparameters': hyperparameters,
        }
        torch.save(checkpoint, model_fname)

        if mAP > mAP_best:
            mAP_best = mAP
            best_fname = os.path.join(track.trial_dir(), "best.tar")
            track.debug("New best score! Saving model")

            torch.save(checkpoint, best_fname)
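A hypothetical counterpart for resuming from one of these .tar files (no such code appears in the original; the key names follow the checkpoint dict saved above):

checkpoint = torch.load(os.path.join(track.trial_dir(), "last.tar"))
target = model.module if type(model) is nn.DataParallel else model
target.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
mAP_best = checkpoint['mAP']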
Example #7
def train(trainloader, model, optimizer, epoch, cuda=True):
    # switch to train mode
    model.train()
    hyperparameters = model.hp
    mode = model.mode

    if type(model) is nn.DataParallel:
        inp_dim = model.module.inp_dim
        pw_ph = model.module.pw_ph
        cx_cy = model.module.cx_cy
        stride = model.module.stride
    else:
        inp_dim = model.inp_dim
        pw_ph = model.pw_ph
        cx_cy = model.cx_cy
        stride = model.stride

    if cuda:
        pw_ph = pw_ph.cuda()
        cx_cy = cx_cy.cuda()
        stride = stride.cuda()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    avg_loss = AverageMeter()
    avg_iou = AverageMeter()
    avg_conf = AverageMeter()
    avg_no_conf = AverageMeter()
    avg_pos = AverageMeter()
    avg_neg = AverageMeter()
    end = time.time()
    break_flag = 0

    if mode['show_temp_summary']:
        writer = SummaryWriter(os.path.join(track.trial_dir(), 'temp_vis/'))

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if cuda:
            inputs = inputs.cuda()

        # compute output
        raw_pred = model(inputs, torch.cuda.is_available())
        true_pred = util.transform(raw_pred.clone().detach(), pw_ph, cx_cy,
                                   stride)
        iou_list = util.get_iou_list(true_pred, targets, hyperparameters,
                                     inp_dim)

        resp_raw_pred, resp_cx_cy, resp_pw_ph, resp_stride, no_obj = util.build_tensors(
            raw_pred, iou_list, pw_ph, cx_cy, stride, hyperparameters)

        stats = helper.get_progress_stats(true_pred, no_obj, iou_list, targets)
        if hyperparameters['wasserstein']:
            no_obj = util.get_wasserstein_matrices(raw_pred, iou_list, inp_dim)

        try:
            loss = util.yolo_loss(resp_raw_pred, targets, no_obj, resp_pw_ph,
                                  resp_cx_cy, resp_stride, inp_dim,
                                  hyperparameters)
        except RuntimeError:
            print('bayes opt failed')
            break_flag = 1
            break

        # measure accuracy and record loss
        avg_loss.update(loss.item())
        avg_iou.update(stats['iou'])
        avg_conf.update(stats['pos_conf'])
        avg_no_conf.update(stats['neg_conf'])
        avg_pos.update(stats['pos_class'])
        avg_neg.update(stats['neg_class'])

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if mode['show_output']:  # plot progress
            progress_str = ('Loss: %.4f | AvIoU: %.3f | AvPConf: %.3f | '
                            'AvNConf: %.5f | AvClass: %.3f | AvNClass: %.5f' %
                            (loss.item(), stats['iou'], stats['pos_conf'],
                             stats['neg_conf'], stats['pos_class'],
                             stats['neg_class']))
            progress_bar(batch_idx, len(trainloader), progress_str)

        iteration = epoch * len(trainloader) + batch_idx

        if mode['show_temp_summary']:
            writer.add_scalar('AvLoss/train', avg_loss.avg, iteration)
            writer.add_scalar('AvIoU/train', avg_iou.avg, iteration)
            writer.add_scalar('AvPConf/train', avg_conf.avg, iteration)
            writer.add_scalar('AvNConf/train', avg_no_conf.avg, iteration)
            writer.add_scalar('AvClass/train', avg_pos.avg, iteration)
            writer.add_scalar('AvNClass/train', avg_neg.avg, iteration)

    track.metric(iteration=iteration,
                 epoch=epoch,
                 avg_train_loss=avg_loss.avg,
                 avg_train_iou=avg_iou.avg,
                 avg_train_conf=avg_conf.avg,
                 avg_train_neg_conf=avg_no_conf.avg,
                 avg_train_pos=avg_pos.avg,
                 avg_train_neg=avg_neg.avg)

    outcome = {
        'avg_loss': avg_loss.avg,
        'avg_iou': avg_iou.avg,
        'avg_pos': avg_pos.avg,
        'avg_neg': avg_neg.avg,
        'avg_conf': avg_conf.avg,
        'avg_no_conf': avg_no_conf.avg,
        'broken': break_flag
    }

    return outcome
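AverageMeter is not defined in the snippet; the conventional implementation (as in the stock PyTorch ImageNet example) matches the .update(val) / .avg usage above:

class AverageMeter:
    """Running average of a scalar; tracks sum, count, and their ratio."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count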
Example #8
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.sqrt_lr:
        lr = args.lr * math.sqrt(args.batch_size / 32.)
    else:
        lr = args.lr

    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=min(args.batch_size, args.max_samples),
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.max_samples,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    with track.trial(args.logroot,
                     None,
                     param_map={'batch_size': args.batch_size}):
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train_loss = train(train_loader, model, criterion, optimizer,
                               epoch)

            # evaluate on validation set
            with torch.no_grad():
                val_loss, prec1 = validate(val_loader, model, criterion)

            track.metric(iteration=epoch,
                         train_loss=train_loss,
                         test_loss=val_loss,
                         prec=prec1)
            # Log model
            model_fname = os.path.join(track.trial_dir(),
                                       "model{}.ckpt".format(epoch))
            torch.save(model, model_fname)

            # Save the model if the validation loss is the best we've seen so far.
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                best_fname = os.path.join(track.trial_dir(), "best.ckpt")
                with open(best_fname, 'wb') as f:
                    torch.save(model, f)
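adjust_learning_rate(optimizer, epoch) is referenced but not shown; a sketch assuming the stock ImageNet schedule (a tenfold decay every 30 epochs, reading the module-level args that main() declares global):

def adjust_learning_rate(optimizer, epoch):
    # Assumed schedule; the actual helper in this project may differ.
    new_lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr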
Example #9
def _checkpoint_file(basename):
    checkpoint_file = os.path.join(track.trial_dir(), 'checkpoints', basename)
    return checkpoint_file
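Hypothetical usage (no caller appears in the snippet); note that the 'checkpoints' subdirectory may need to be created before the first save:

import os

import torch

path = _checkpoint_file('epoch_3.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)  # 'checkpoints/' may not exist yet
torch.save({'step': 0}, path)  # placeholder payload for illustration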
Example #10
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
net = ResNet34()
net = net.to(device)
if device == "cuda":
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint from the trial directory managed by track.
    print("==> Resuming from checkpoint..")
    ckpt_path = os.path.join(track.trial_dir(), "ckpt.pth")
    assert os.path.isfile(ckpt_path), "Error: no checkpoint found!"
    checkpoint = torch.load(ckpt_path)
    net.load_state_dict(checkpoint["net"])
    best_acc = checkpoint["acc"]
    start_epoch = checkpoint["epoch"]

criterion = nn.CrossEntropyLoss()
optimizer = SGDLRD(
    net.parameters(),
    lr=args.lr,
    lr_dropout_rate=args.lr_dropout_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# NOTE: the original snippet is truncated after `optimizer=optimizer,`; the
# milestones and decay factor below are illustrative placeholders, not the
# author's values.
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                              milestones=[150, 250],
                                              gamma=0.1)
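The rest of the script is not shown; presumably an epoch loop steps the scheduler once per epoch, along these lines (a sketch; train() and test() are assumed to be this script's helpers, as in Example #4):

# start_epoch is assumed to be 0 when not resuming.
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    test(epoch)
    lr_scheduler.step()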