# Exemplo n.º 1 (0)  -- listing-site header, commented out so the file parses
def main():
    """Train (or evaluate) a classification model according to ``cfg.CLS``.

    All settings come from the module-level ``cfg``/``args`` objects. The
    function builds the data loaders, model, loss and optimizer, then either
    runs a single evaluation pass (``cfg.CLS.evaluate``) or the full
    train/validate loop, logging and checkpointing every epoch.

    Side effects: updates the module globals ``BEST_ACC`` and ``LR_STATE``
    (the latter is mutated inside ``mixup_train``/``test`` — TODO confirm),
    writes checkpoints, logs and curves under ``cfg.CLS.ckpt``.
    """
    global BEST_ACC, LR_STATE
    start_epoch = cfg.CLS.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Create ckpt folder and archive the config file used for this run.
    if not os.path.isdir(cfg.CLS.ckpt):
        mkdir_p(cfg.CLS.ckpt)
    if args.cfg_file is not None and not cfg.CLS.evaluate:
        shutil.copyfile(
            args.cfg_file,
            os.path.join(cfg.CLS.ckpt,
                         args.cfg_file.split('/')[-1]))

    # Dataset and Loader. The training augmentation pipeline is assembled
    # conditionally: rotation / pixel jitter / grayscale are only added when
    # enabled in the config.
    normalize = transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
    train_aug = [
        transforms.RandomResizedCrop(cfg.CLS.crop_size),
        transforms.RandomHorizontalFlip()
    ]
    if len(cfg.CLS.rotation) > 0:
        train_aug.append(transforms.RandomRotation(cfg.CLS.rotation))
    if len(cfg.CLS.pixel_jitter) > 0:
        train_aug.append(RandomPixelJitter(cfg.CLS.pixel_jitter))
    if cfg.CLS.grayscale > 0:
        train_aug.append(transforms.RandomGrayscale(cfg.CLS.grayscale))
    train_aug.append(transforms.ToTensor())
    train_aug.append(normalize)

    traindir = os.path.join(cfg.CLS.data_root, cfg.CLS.train_folder)
    train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir, transforms.Compose(train_aug)),
                                               batch_size=cfg.CLS.train_batch,
                                               shuffle=True,
                                               num_workers=cfg.workers,
                                               pin_memory=True)

    # The validation loader is only built when it will actually be used.
    if cfg.CLS.validate or cfg.CLS.evaluate:
        valdir = os.path.join(cfg.CLS.data_root, cfg.CLS.val_folder)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(cfg.CLS.base_size),
                transforms.CenterCrop(cfg.CLS.crop_size),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=cfg.CLS.test_batch,
                                                 shuffle=False,
                                                 num_workers=cfg.workers,
                                                 pin_memory=True)

    # Create model. A throwaway instance is used for FLOPs/param counting
    # because measure_model may mutate the model; a fresh one is built after.
    model = models.__dict__[cfg.CLS.arch]()
    print(model)
    # Calculate FLOPs & Param
    n_flops, n_convops, n_params = measure_model(model, cfg.CLS.crop_size,
                                                 cfg.CLS.crop_size)
    print('==> FLOPs: {:.4f}M, Conv_FLOPs: {:.4f}M, Params: {:.4f}M'.format(
        n_flops / 1e6, n_convops / 1e6, n_params / 1e6))
    del model
    model = models.__dict__[cfg.CLS.arch]()

    # Load pre-train model
    if cfg.CLS.pretrained:
        print("==> Using pre-trained model '{}'".format(cfg.CLS.pretrained))
        pretrained_dict = torch.load(cfg.CLS.pretrained)
        # Checkpoints may wrap the weights in a 'state_dict' key; unwrap when
        # present. (Previously a bare ``except`` hid any unrelated error.)
        if isinstance(pretrained_dict, dict) and 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        updated_dict, match_layers, mismatch_layers = weight_filler(
            pretrained_dict, model_dict)
        model_dict.update(updated_dict)
        model.load_state_dict(model_dict)
    else:
        print("==> Creating model '{}'".format(cfg.CLS.arch))

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    if cfg.CLS.pretrained:
        # Layers that did not match the checkpoint are freshly initialised,
        # so they train with a 10x larger learning rate than matched layers.
        new_params = [
            param for name, param in model.named_parameters()
            if name in mismatch_layers
        ]
        base_params = [
            param for name, param in model.named_parameters()
            if name in match_layers
        ]
        model_params = [{
            'params': base_params
        }, {
            'params': new_params,
            'lr': cfg.CLS.base_lr * 10
        }]
    else:
        model_params = model.parameters()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    optimizer = optim.SGD(model_params,
                          lr=cfg.CLS.base_lr,
                          momentum=cfg.CLS.momentum,
                          weight_decay=cfg.CLS.weight_decay)

    # Evaluate model and return early — no training, logging or checkpointing.
    if cfg.CLS.evaluate:
        print('\n==> Evaluation only')
        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               start_epoch, USE_CUDA)
        print(
            '==> Test Loss: {:.8f} | Test_top1: {:.4f}% | Test_top5: {:.4f}%'.
            format(test_loss, test_top1, test_top5))
        return

    # Resume training
    title = 'Pytorch-CLS-' + cfg.CLS.arch
    if cfg.CLS.resume:
        # Load checkpoint: restore best accuracy, epoch, weights and
        # optimizer state, and append to the existing log file.
        print("==> Resuming from checkpoint '{}'".format(cfg.CLS.resume))
        assert os.path.isfile(
            cfg.CLS.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(cfg.CLS.resume)
        BEST_ACC = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(cfg.CLS.ckpt, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(cfg.CLS.ckpt, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    # Train and val
    for epoch in range(start_epoch, cfg.CLS.epochs):
        print('\nEpoch: [{}/{}] | LR: {:.8f}'.format(epoch + 1, cfg.CLS.epochs,
                                                     LR_STATE))

        train_loss, train_acc = mixup_train(train_loader, model, criterion,
                                            optimizer, epoch, USE_CUDA)
        if cfg.CLS.validate:
            test_loss, test_top1, test_top5 = test(val_loader, model,
                                                   criterion, epoch, USE_CUDA)
        else:
            # No validation requested: log placeholder metrics.
            test_loss, test_top1, test_top5 = 0.0, 0.0, 0.0

        # Append logger file
        logger.append([LR_STATE, train_loss, test_loss, train_acc, test_top1])
        # Save model
        save_checkpoint(model, optimizer, test_top1, epoch)
        # Draw curve — best-effort: a plotting failure must not kill training.
        # (Narrowed from a bare ``except`` that also swallowed KeyboardInterrupt.)
        try:
            draw_curve(cfg.CLS.arch, cfg.CLS.ckpt)
            print('==> Success saving log curve...')
        except Exception:
            print('==> Saving log curve error...')

    logger.close()
    # Archiving the final curve/log is also best-effort.
    try:
        savefig(os.path.join(cfg.CLS.ckpt, 'log.eps'))
        shutil.copyfile(
            os.path.join(cfg.CLS.ckpt, 'log.txt'),
            os.path.join(
                cfg.CLS.ckpt, 'log{}.txt'.format(
                    datetime.datetime.now().strftime('%Y%m%d%H%M%S'))))
    except Exception:
        print('copy log error.')
    print('==> Training Done!')
    print('==> Best acc: {:.4f}%'.format(BEST_ACC))
# Exemplo n.º 2 (0)  -- listing-site header, commented out so the file parses
# ---------------------------------------------------

# Set GPU id, CUDA and cudnn
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
USE_CUDA = torch.cuda.is_available()
cudnn.benchmark = True

# Create & Load model: match the checkpoint weights against a freshly built
# architecture via ``weight_filler`` (partial loading of compatible layers).
MODEL = models.__dict__[args.arch]()
checkpoint = torch.load(args.model_weights)
# Checkpoints may wrap the weights in a 'state_dict' key; unwrap when present.
# (Previously a bare ``except`` masked any unrelated failure here.)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    weight_dict = checkpoint['state_dict']
else:
    weight_dict = checkpoint
model_dict = MODEL.state_dict()
updated_dict, match_layers, mismatch_layers = weight_filler(
    weight_dict, model_dict)
model_dict.update(updated_dict)
MODEL.load_state_dict(model_dict)

# Switch to evaluate mode
MODEL.cuda().eval()
print(MODEL)

# Create log & dict
LOG_PTH = './data/log{}.txt'.format(
    datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
SET_DICT = dict()
# NOTE(review): handle is deliberately left open here — it is consumed by the
# loop that follows; consider restructuring into a ``with`` block there.
f = open(args.val_file, 'r')
img_order = 0
for _ in f:
    img_dict = dict()