Example #1
    config.net_config, net_type = utils.load_net_config(
        os.path.join(args.load_path, 'net_config'))

    derivedNetwork = getattr(model_derived, '%s_Net' % net_type.upper())
    model = derivedNetwork(config.net_config, config=config, num_classes=1000)

    logging.info("Network Structure: \n" +
                 '\n'.join(map(str, model.net_config)))
    logging.info("Params = %.2fMB" % utils.count_parameters_in_MB(model))
    logging.info("Mult-Adds = %.2fMB" %
                 comp_multadds(model, input_size=config.data.input_size))

    model = model.cuda()
    model = nn.DataParallel(model)
    checkpoint = torch.load(os.path.join(args.load_path, 'weight.pt'),
                            map_location="cpu")  # weight checkpoint
    # strict=False ignores missing/unexpected keys in the checkpoint
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    imagenet = imagenet_data.ImageNet12(
        trainFolder=os.path.join(args.data_path, 'train'),
        testFolder=os.path.join(args.data_path, 'val'),
        num_workers=config.data.num_workers,
        data_config=config.data)
    valid_queue = imagenet.getTestLoader(config.data.batch_size)
    trainer = Trainer(None, valid_queue, None, None, None, config,
                      args.report_freq)

    with torch.no_grad():
        val_acc_top1, val_acc_top5, valid_obj, batch_time = trainer.infer(
            model)
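
Note: Trainer.infer above is project-specific. Purely as a hedged sketch of what such a validation pass typically computes (top-1/top-5 accuracy over the test loader), assuming a standard (images, targets) loader; the names here are illustrative, not the repo's API:

import torch

@torch.no_grad()
def evaluate(model, valid_queue, device='cuda'):
    # hypothetical stand-in for Trainer.infer: plain top-1/top-5 accuracy
    model.eval()
    top1 = top5 = total = 0
    for images, targets in valid_queue:
        images, targets = images.to(device), targets.to(device)
        logits = model(images)
        _, pred = logits.topk(5, dim=1)          # (batch, 5) class indices
        correct = pred.eq(targets.unsqueeze(1))  # broadcast against targets
        top1 += correct[:, 0].sum().item()
        top5 += correct.any(dim=1).sum().item()
        total += targets.size(0)
    return 100.0 * top1 / total, 100.0 * top5 / total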
Example #2
    model = model.cuda()

    if config.optim.label_smooth:
        # the label-smoothing criterion is a plain function, so no .cuda() needed
        criterion = utils.cross_entropy_with_label_smoothing
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                config.optim.init_lr,
                                momentum=config.optim.momentum,
                                weight_decay=config.optim.weight_decay)

    imagenet = imagenet_data.ImageNet12(
        trainFolder=os.path.join(args.data_path, 'train'),
        testFolder=os.path.join(args.data_path, 'val'),
        num_workers=config.data.num_workers,
        type_of_data_augmentation=config.data.type_of_data_aug,
        data_config=config.data)

    if config.optim.use_multi_stage:
        (train_queue,
         weak_train_queue), valid_queue = imagenet.getSetTrainTestLoader(
             config.data.batch_size)
    else:
        train_queue, valid_queue = imagenet.getTrainTestLoader(
            config.data.batch_size)

    num_train_samples = len(train_queue.dataset)
    scheduler = get_lr_scheduler(config, optimizer, num_train_samples)
    # when resuming, fast-forward the step-based scheduler to the current step
    scheduler.last_step = start_epoch * (
        num_train_samples // config.data.batch_size + 1) - 1
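
Note: utils.cross_entropy_with_label_smoothing is defined in the repo; the snippet only selects it. As a minimal sketch of the standard label-smoothing formulation it likely implements (the smoothing factor eps is an assumed parameter, not the repo's signature):

import torch
import torch.nn.functional as F

def cross_entropy_with_label_smoothing(logits, target, eps=0.1):
    # mix the one-hot target with a uniform distribution over the classes
    n_classes = logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)
    smooth = torch.full_like(log_probs, eps / n_classes)
    smooth.scatter_(1, target.unsqueeze(1), 1.0 - eps + eps / n_classes)
    return -(smooth * log_probs).sum(dim=1).mean()

With eps=0 this reduces to ordinary cross-entropy, which is why the config can toggle between the two criteria.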
Example #3
def main():
    global best_err1, best_err5

    if config.train_params.use_seed:
        utils.set_seed(config.train_params.seed)

    imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(config.data.data_path, 'train'),
                                        testFolder=os.path.join(config.data.data_path, 'val'),
                                        num_workers=config.data.num_workers,
                                        type_of_data_augmentation=config.data.type_of_data_aug,
                                        data_config=config.data)

    train_loader, val_loader = imagenet.getTrainTestLoader(config.data.batch_size)

    if config.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif config.net_type == 'resnet':
        t_net = ResNet.resnet34(pretrained=True)
        s_net = ResNet.resnet18(pretrained=False)
    else:
        raise RuntimeError('unsupported network type: %s' % config.net_type)

    import knowledge_distiller
    d_net = knowledge_distiller.WSLDistiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(sum(p.numel() for p in t_net.parameters())))
    print('the number of student model parameters: {}'.format(sum(p.numel() for p in s_net.parameters())))

    t_net = torch.nn.DataParallel(t_net)
    s_net = torch.nn.DataParallel(s_net)
    d_net = torch.nn.DataParallel(d_net)

    if config.optim.if_resume:
        checkpoint = torch.load(config.optim.resume_path)
        d_net.module.load_state_dict(checkpoint['train_state_dict'])
        best_err1 = checkpoint['best_err1']
        best_err5 = checkpoint['best_err5']
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0


    t_net = t_net.cuda()
    s_net = s_net.cuda()
    d_net = d_net.cuda()

    # the optimizer updates only the student network's parameters
    optimizer = torch.optim.SGD(s_net.parameters(), config.optim.init_lr,
                                momentum=config.optim.momentum,
                                weight_decay=config.optim.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    print('Teacher network performance')
    validate(val_loader, t_net, 0)

    for epoch in range(start_epoch, config.train_params.epochs + 1):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, epoch)

        # remember the best top-1 error and save a checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best (top-1 and top-5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': s_net.module.state_dict(),
            'train_state_dict': d_net.module.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best result (top-1 and top-5 error):', best_err1, best_err5)
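
Note: WSLDistiller (weighted soft label distillation) comes from the repo's knowledge_distiller module. As a rough baseline sketch of the classic soft-label distillation loss that such teacher-student wrappers build on (Hinton et al.); WSLD additionally weights the soft-label term per sample, which is not shown here, and the temperature T and mixing weight alpha are illustrative assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftLabelDistiller(nn.Module):
    # generic teacher-student wrapper, not the repo's WSLDistiller
    def __init__(self, t_net, s_net, T=4.0, alpha=0.9):
        super().__init__()
        self.t_net, self.s_net = t_net, s_net
        self.T, self.alpha = T, alpha

    def forward(self, x, target):
        with torch.no_grad():        # the teacher is frozen
            t_logits = self.t_net(x)
        s_logits = self.s_net(x)
        # KL divergence between temperature-softened distributions, scaled by T^2
        kd = F.kl_div(F.log_softmax(s_logits / self.T, dim=1),
                      F.softmax(t_logits / self.T, dim=1),
                      reduction='batchmean') * self.T ** 2
        ce = F.cross_entropy(s_logits, target)
        return s_logits, self.alpha * kd + (1 - self.alpha) * ce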