def _main():
    """Score the DFDC test videos with an Xception + WS-DAN ensemble.

    Faces are extracted by ``DFDCLoader`` using a RetinaFace detector.
    Each batch of face tensors is scored by both networks, and the fake-class
    probabilities are blended 25% Xception / 75% WS-DAN. The blend is handed
    back to the loader via ``loader.feedback``. The loader presumably
    aggregates these into per-video scores exposed as ``loader.score`` (keyed
    by file name, as read back below). Finally the sample submission is
    rewritten to ``submission.csv`` with those scores.
    """
    # Pure inference: no autograd bookkeeping; let cuDNN autotune kernels.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    detector = FaceDetector()
    detector.load_checkpoint("../input/pretrained/RetinaFace-Resnet50-fixed.pth")
    loader = DFDCLoader("../input/deepfake-detection-challenge/test_videos",
                        detector, T.ToTensor())

    # Xception branch: restore weights, then move to the GPU.
    xcp_net = xception(num_classes=2, pretrained=False)
    xcp_net.load_state_dict(torch.load("../input/pretrained/xception.pth")["state_dict"])
    xcp_net = xcp_net.cuda()
    xcp_net.eval()

    # WS-DAN branch: move to the GPU, then restore weights.
    wsdan_net = WSDAN(num_classes=2, M=8, net="xception", pretrained=False).cuda()
    wsdan_net.load_state_dict(torch.load("../input/pretrained/wsdan.pth")["state_dict"])
    wsdan_net.eval()

    # Per-channel normalisation statistics for the WS-DAN input, shaped to
    # broadcast over (N, C, H, W).
    chan_mean = torch.Tensor([.4479, .3744, .3473]).view(1, 3, 1, 1).cuda()
    chan_std = torch.Tensor([.2537, .2502, .2424]).view(1, 3, 1, 1).cuda()

    for faces in loader:
        faces = faces.cuda(non_blocking=True)

        # Xception input: 299x299, rescaled from [0, 1] (assumed ToTensor range)
        # to [-1, 1].
        xcp_in = F.interpolate(faces, size=299, mode="bilinear")
        xcp_in.sub_(0.5).mul_(2.0)
        p_xcp = xcp_net(xcp_in).softmax(-1)[:, 1].cpu().numpy()

        # WS-DAN returns a tuple; only the first element (logits) is used.
        wsdan_logits, _, _ = wsdan_net((faces - chan_mean) / chan_std)
        p_wsdan = wsdan_logits.softmax(-1)[:, 1].cpu().numpy()

        loader.feedback(0.25 * p_xcp + 0.75 * p_wsdan)

    template_path = "../input/deepfake-detection-challenge/sample_submission.csv"
    with open(template_path) as template, open("submission.csv", "w") as out:
        out.write(next(template))  # copy the header line unchanged
        for row in template:
            video = row.split(",", 1)[0]
            print("%s,%.6f" % (video, loader.score[video]), file=out)
Exemplo n.º 2
0
def main():
    """Train WS-DAN (4 classes) on the RFW face dataset from the command line.

    Parses the options below and logs to ``<save_dir>/<log_name>``. Builds the
    train/val/test loaders and optionally restores a checkpoint (weights,
    ``logs`` bookkeeping and feature centers). Then trains with SGD + StepLR,
    letting ``ModelCheckpoint`` track the best model by the monitored
    validation metric.

    Relies on names defined elsewhere in the module: ``device``,
    ``raw_metric``, ``train``, ``validate``, ``CustomDataset``, ``WSDAN``,
    ``ModelCheckpoint``.
    """
    parser = OptionParser()
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int',
                      help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int',
                      help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int',
                      help='batch size (default: 16)')
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False,
                      help='load checkpoint model (default: False)')
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int',
                      help='show information for each <verbose> iterations (default: 100)')
    parser.add_option('--lr', '--learning-rate', dest='lr', default=1e-3, type='float',
                      help='learning rate (default: 1e-3)')
    parser.add_option('--sf', '--save-freq', dest='save_freq', default=1, type='int',
                      help='saving frequency of .ckpt models (default: 1)')
    parser.add_option('--sd', '--save-dir', dest='save_dir', default='./models/wsdan/',
                      help='saving directory of .ckpt models (default: ./models/wsdan)')
    parser.add_option('--ln', '--log-name', dest='log_name', default='train.log',
                      help='log name  (default: train.log)')
    parser.add_option('--mn', '--model-name', dest='model_name', default='model.ckpt',
                      help='model name  (default:model.ckpt)')
    parser.add_option('--init', '--initial-training', dest='initial_training', default=1, type='int',
                      help='train from 1-beginning or 0-resume training (default: 1)')

    (options, args) = parser.parse_args()

    ##################################
    # Initialize saving directory
    ##################################
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(options.save_dir, exist_ok=True)

    ##################################
    # Logging setting
    ##################################
    logging.basicConfig(
        filename=os.path.join(options.save_dir, options.log_name),
        filemode='w',
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Load dataset
    ##################################
    image_size = (256, 256)
    num_classes = 4
    transform = transforms.Compose([transforms.Resize(size=image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
    # NOTE(review): train CSV is spelled 'Metada' while val uses 'Metadata' --
    # confirm the file name on disk.
    train_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/',
                                  csv_file='data/RFW_Train40k_Images_Metada.csv',
                                  transform=transform)
    val_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/',
                                csv_file='data/RFW_Val4k_Images_Metadata.csv',
                                transform=transform)
    test_dataset = CustomDataset(data_root='/mnt/HDD/RFW/test/data/',
                                 csv_file='data/RFW_Test_Images_Metadata.csv',
                                 transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=options.batch_size, shuffle=True,
                              num_workers=options.workers, pin_memory=True)
    validate_loader = DataLoader(val_dataset, batch_size=options.batch_size * 4, shuffle=False,
                                 num_workers=options.workers, pin_memory=True)
    # test_loader is built but not used by the training loop below.
    test_loader = DataLoader(test_dataset, batch_size=options.batch_size * 4, shuffle=False,
                             num_workers=options.workers, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    logs = {}
    start_epoch = 0
    num_attentions = 32
    net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True)

    # feature_center: size of (#classes, #attention_maps * #channel_features)
    feature_center = torch.zeros(num_classes, num_attentions * net.num_features).to(device)

    if options.ckpt:
        ckpt = options.ckpt
        # Load once (previously loaded and applied twice). Restore on CPU; the
        # net is moved to `device` below, and feature_center is moved explicitly.
        checkpoint = torch.load(ckpt, map_location='cpu')

        # Epoch / lr / best-score bookkeeping saved alongside the weights (optional).
        logs = checkpoint.get('logs', {})
        start_epoch = int(logs.get('epoch', 0))
        if options.initial_training == 0:
            # Resume mode: the epoch is encoded in the file name (e.g. "003.ckpt").
            start_epoch = int(os.path.basename(ckpt).split('.')[0])
        # NOTE(review): with --init 1 the start epoch still comes from the saved
        # logs (unchanged behaviour); confirm whether "train from beginning"
        # should reset it to 0.

        net.load_state_dict(checkpoint['state_dict'])
        logging.info('Network loaded from {}'.format(ckpt))

        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(ckpt))

    logging.info('Network weights save to {}'.format(options.save_dir))

    ##################################
    # Use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Optimizer, LR Scheduler
    ##################################
    # When resuming, continue from the learning rate recorded in the logs.
    learning_rate = logs['lr'] if 'lr' in logs else options.lr
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    ##################################
    # ModelCheckpoint
    ##################################
    callback_monitor = 'val_{}'.format(raw_metric.name)
    callback = ModelCheckpoint(savepath=os.path.join(options.save_dir, options.model_name),
                               monitor=callback_monitor,
                               mode='max')
    if callback_monitor in logs:
        callback.set_best_score(logs[callback_monitor])
    else:
        callback.reset()

    ##################################
    # TRAINING
    ##################################
    logging.info('')
    logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
                 format(options.epochs, options.batch_size, len(train_dataset), len(val_dataset)))

    for epoch in range(start_epoch, options.epochs):
        callback.on_epoch_begin()

        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch + 1, options.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs,
                 data_loader=validate_loader,
                 net=net,
                 pbar=pbar)

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(logs['val_loss'])
        else:
            scheduler.step()

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()
Exemplo n.º 3
0
def predict(image_path,
            model_param_path,
            save_path,
            img_save_name,
            resize=(224, 224),
            gen_hm=False,
            label_csv="../data/train.csv"):
    """Classify one image with WS-DAN, averaging full-image and attention-crop logits.

    Args:
        image_path: path of the image to classify (converted to RGB).
        model_param_path: path of a saved ``state_dict`` for ``WSDAN(num_classes=4)``.
            If the path contains ``'gpu'`` the model is run on CUDA, falling back
            to CPU when the model cannot be moved there.
        save_path: directory for the heat-map images written when ``gen_hm`` is True.
        img_save_name: file-name prefix for those heat-map images.
        resize: (height, width) the image is resized to before inference.
        gen_hm: if True, also save raw / raw-attention / heat-attention images.
        label_csv: CSV with an ``image_id`` column plus the label columns
            ``healthy, multiple_diseases, rust, scab``. The image's ground-truth
            row is looked up by file stem.

    Returns:
        ``(y_pred, label)``: softmax class probabilities for the single-image
        batch, and the ground-truth label tensor (or ``None`` when the image
        id is not present in ``label_csv``).
    """
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        # transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
        transforms.Resize(size=(int(resize[0]), int(resize[1]))),
        # transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    X = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    net = WSDAN(num_classes=4)
    # map_location='cpu' lets weights saved from a GPU run load on any machine;
    # the network is moved to the chosen device below.
    net.load_state_dict(torch.load(model_param_path, map_location='cpu'))
    net.eval()

    device = torch.device("cpu")
    if 'gpu' in model_param_path:
        print("please make sure your computer has a GPU")
        try:
            net.to(torch.device("cuda"))
            device = torch.device("cuda")
        except (AssertionError, RuntimeError):
            # Typically AssertionError for a CPU-only torch build, RuntimeError
            # when no driver/device is available. Fall back to CPU so the model
            # and the input tensor end up on the same device.
            print("No GPU in the environment")
            net.to(device)
    X = X.to(device)

    with torch.no_grad():
        # WS-DAN on the full image; collapse the attention maps into one map.
        y_pred_raw, _, attention_maps = net(X)
        attention_maps = torch.mean(attention_maps, dim=1, keepdim=True)

        # Augmentation with crop_mask: re-run on the attended region.
        crop_image = batch_augment(X,
                                   attention_maps,
                                   mode='crop',
                                   theta=0.1,
                                   padding_ratio=0.05)
        y_pred_crop, _, _ = net(crop_image)

        # dim=1 is what the deprecated implicit-dim softmax picked for 2-D
        # (batch, classes) logits -- now stated explicitly.
        y_pred = F.softmax((y_pred_raw + y_pred_crop) / 2., dim=1)

    if gen_hm:
        # Equivalent to the deprecated F.upsample_bilinear (align_corners=True).
        attention_maps = F.interpolate(attention_maps,
                                       size=(X.size(2), X.size(3)),
                                       mode='bilinear',
                                       align_corners=True)
        attention_maps = torch.sqrt(attention_maps.cpu() /
                                    attention_maps.max().item())

        # get heat attention maps
        heat_attention_maps = generate_heatmap(attention_maps)

        # raw_image, heat_attention, raw_attention
        raw_image = X.cpu() * STD + MEAN
        heat_attention_image = raw_image * 0.4 + heat_attention_maps * 0.6
        raw_attention_image = raw_image * attention_maps

        for batch_idx in range(X.size(0)):
            rimg = ToPILImage(raw_image[batch_idx])
            raimg = ToPILImage(raw_attention_image[batch_idx])
            haimg = ToPILImage(heat_attention_image[batch_idx])
            rimg.save(
                os.path.join(save_path, '{}_raw.jpg'.format(img_save_name)))
            raimg.save(
                os.path.join(save_path,
                             '{}_raw_atten.jpg'.format(img_save_name)))
            haimg.save(
                os.path.join(save_path,
                             '{}_heat_atten.jpg'.format(img_save_name)))

    # Ground-truth lookup by file stem (works for any extension length).
    # label stays None when the id is absent, instead of raising UnboundLocalError.
    df = pd.read_csv(label_csv)
    image_id = os.path.splitext(os.path.basename(image_path))[0]
    label_cols = ['healthy', 'multiple_diseases', 'rust', 'scab']
    matches = df.loc[df['image_id'] == image_id, label_cols]
    label = torch.tensor(matches.iloc[0].to_numpy()) if len(matches) else None
    return y_pred, label
Exemplo n.º 4
0
def main():
    """Evaluate ``config.eval_ckpt`` on the car test set and write predictions.csv.

    Each batch is scored as the average of the raw-image and attention-crop
    WS-DAN logits. Running top-1/top-5 accuracy of both is shown on the
    progress bar. If ``visualize`` is set, attention images are saved to
    ``savepath``. After all batches the arg-max predictions are mapped to
    label names and written once to ``<savepath>/predictions.csv``.

    Relies on module-level names defined elsewhere: ``config``, ``device``,
    ``visualize``, ``savepath``, ``STD``, ``MEAN``, ``ToPILImage``,
    ``generate_heatmap``, ``mapping``, ``CarDataset``, ``WSDAN``,
    ``batch_augment``, ``TopKAccuracyMetric``.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=2,
                             pin_memory=True)
    name2label, label2name = mapping('../training_labels.csv')

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=config.num_attentions,
                net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y, batch_ids) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(batch_ids)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1,
                                       padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Collect logits; the submission is built once after the loop
            # (previously re-concatenated and re-written on every batch).
            logits.append(y_pred.cpu())

            if visualize:
                # reshape attention maps (same as deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2), X.size(3)),
                                               mode='bilinear',
                                               align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    img_idx = i * config.batch_size + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % img_idx))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % img_idx))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % img_idx))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # Save the predictions (skipped for an empty loader, where torch.cat would fail).
    if logits:
        prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)
        submission = pd.DataFrame({
            'id': ids,
            'label': [label2name[x] for x in prediction.numpy()],
        })
        submission.to_csv(os.path.join(savepath, 'predictions.csv'), index=False)
Exemplo n.º 5
0
    val_loader_crop = torch.utils.data.DataLoader(
      datasets.ImageFolder(args.data_crop + VALID_IMAGES,
                          transform=data_transforms_val),
      batch_size=2, shuffle=False, num_workers=1)

# Choose one compute device so the model, the feature centers and any loaded
# checkpoint all live together (previously feature_center was always put on
# CUDA while the model moved there only when use_cuda was set).
device = torch.device("cuda" if use_cuda else "cpu")
print("define wsdan")
model = WSDAN(num_classes=args.num_classes, M=num_attentions, net=NET, pretrained=True)
# feature_center: size of (#classes, #attention_maps * #channel_features)
feature_center = torch.zeros(args.num_classes, num_attentions * model.num_features).to(device)
center_loss = CenterLoss()
cross_entropy_loss = nn.CrossEntropyLoss()

if args.model:
    print("loading pretrained model")
    # map_location lets a GPU-saved state_dict be restored on a CPU-only run.
    checkpoint = torch.load(args.model, map_location=device)
    model.load_state_dict(checkpoint)

if use_cuda:
    print('Using GPU')
else:
    print('Using CPU')
model.to(device)

# Alternatives previously tried: Adam, SGD (+ CosineAnnealingLR), AdaBelief.
optimizer = RangerAdaBelief(model.parameters(), lr=args.lr, eps=1e-12, betas=(0.9,0.999))
Exemplo n.º 6
0
def main():
    """Evaluate the WS-DAN checkpoint given as ``sys.argv[1]`` on the car test set.

    Runs with batch size 1 and averages raw-image and attention-crop logits.
    It tracks top-1/top-5 accuracy and, when ``visualize`` is set, saves
    attention images to ``savepath``. The final accuracy is logged at the end.

    Relies on module-level names defined elsewhere: ``device``, ``visualize``,
    ``savepath``, ``STD``, ``MEAN``, ``ToPILImage``, ``generate_heatmap``,
    ``CarDataset``, ``WSDAN``, ``batch_augment``, ``TopKAccuracyMetric``.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = sys.argv[1]
    except IndexError:
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=32,
                net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()

    # Stays None if the loader yields nothing (avoids NameError in the final log).
    epoch_acc = None
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1)

            y_pred_crop, _, _ = net(crop_image)
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # reshape attention maps (same as deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2), X.size(3)),
                                               mode='bilinear',
                                               align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    # Global sample index; `i + batch_idx` only worked for batch_size=1.
                    img_idx = i * test_loader.batch_size + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % img_idx))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % img_idx))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % img_idx))

            # Top K
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(
                epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for this epoch
    if epoch_acc is None:
        logging.info('No test samples were evaluated')
    else:
        logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
Exemplo n.º 7
0
def main():
    """Train WS-DAN on the car dataset using settings from ``config``.

    Logs to ``<config.save_dir>/<config.log_name>``. Optionally resumes from
    ``config.ckpt`` (weights, ``logs`` bookkeeping, feature centers). Trains
    with SGD + StepLR and lets ``ModelCheckpoint`` keep the best model by the
    monitored validation metric.

    Relies on module-level names defined elsewhere: ``config``, ``device``,
    ``raw_metric``, ``train``, ``validate``, ``CarDataset``, ``WSDAN``,
    ``ModelCheckpoint``.
    """
    ##################################
    # Initialize saving directory
    ##################################
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(config.save_dir, exist_ok=True)

    ##################################
    # Logging setting
    ##################################
    logging.basicConfig(
        filename=os.path.join(config.save_dir, config.log_name),
        filemode='w',
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Load dataset
    ##################################
    full_train_dataset = CarDataset('train')
    # NOTE(review): validation uses the very same samples as training (the
    # 80/20 torch.utils.data.random_split was disabled). Val metrics -- and so
    # the "best" checkpoint selected by ModelCheckpoint -- therefore measure
    # training fit, not generalisation.
    train_dataset = full_train_dataset
    validate_dataset = full_train_dataset
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                              num_workers=config.workers, pin_memory=True)
    validate_loader = DataLoader(validate_dataset, batch_size=config.batch_size * 4, shuffle=False,
                                 num_workers=config.workers, pin_memory=True)
    num_classes = full_train_dataset.num_classes

    ##################################
    # Initialize model
    ##################################
    logs = {}
    start_epoch = 0
    net = WSDAN(num_classes=num_classes,
                M=config.num_attentions,
                net=config.net,
                pretrained=True)

    # feature_center: size of (#classes, #attention_maps * #channel_features)
    feature_center = torch.zeros(num_classes, config.num_attentions *
                                 net.num_features).to(device)

    if config.ckpt:
        # Restore on CPU; the net is moved to `device` below and the feature
        # center explicitly, so GPU-saved checkpoints load anywhere.
        checkpoint = torch.load(config.ckpt, map_location='cpu')

        # Epoch / lr / best-score bookkeeping (optional in the checkpoint).
        logs = checkpoint.get('logs', {})
        start_epoch = int(logs.get('epoch', 0))

        # Load weights
        net.load_state_dict(checkpoint['state_dict'])
        logging.info('Network loaded from {}'.format(config.ckpt))

        # load feature center
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(config.ckpt))

    logging.info('Network weights save to {}'.format(config.save_dir))

    ##################################
    # Use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Optimizer, LR Scheduler
    ##################################
    # When resuming, continue from the learning rate recorded in the logs.
    learning_rate = logs['lr'] if 'lr' in logs else config.learning_rate
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=2,
                                                gamma=0.9)

    ##################################
    # ModelCheckpoint
    ##################################
    callback_monitor = 'val_{}'.format(raw_metric.name)
    callback = ModelCheckpoint(savepath=os.path.join(config.save_dir,
                                                     config.model_name),
                               monitor=callback_monitor,
                               mode='max')
    if callback_monitor in logs:
        callback.set_best_score(logs[callback_monitor])
    else:
        callback.reset()

    ##################################
    # TRAINING
    ##################################
    logging.info(
        'Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'
        .format(config.epochs, config.batch_size, len(train_dataset),
                len(validate_dataset)))
    logging.info('')

    for epoch in range(start_epoch, config.epochs):
        callback.on_epoch_begin()

        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs, data_loader=validate_loader, net=net, pbar=pbar)

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(logs['val_loss'])
        else:
            scheduler.step()

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()
Exemplo n.º 8
0
def main(result_arr, label_map_csv='/home/naman/Documents/Assignment_Job/out_dict.csv'):
    """Evaluate ``config.eval_ckpt`` and collect per-batch predictions into ``result_arr``.

    Args:
        result_arr: list that receives one ``(y_pred, predicted_names)`` tuple
            per test batch. ``y_pred`` is the averaged raw/crop logits (on CPU).
            ``predicted_names`` holds the looked-up names of the arg-max classes.
        label_map_csv: two-column CSV. Each row ``k,v`` is inverted into the
            lookup ``{v: k}``, which is indexed by the predicted class index as
            a string. Presumably rows are ``label_name,class_index`` -- verify
            against the file.

    Relies on module-level names defined elsewhere: ``config``, ``device``,
    ``visualize``, ``savepath``, ``STD``, ``MEAN``, ``ToPILImage``,
    ``generate_heatmap``, ``get_trainval_datasets``, ``WSDAN``,
    ``batch_augment``, ``TopKAccuracyMetric``.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # Build the class-index -> name lookup once, closing the file afterwards
    # (it was previously re-opened on every batch and never closed).
    with open(label_map_csv, 'r', newline='') as fh:
        index_to_name = {v: k for k, v in csv.reader(fh)}

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Map each sample's arg-max class index to its name. This replaces
            # the original `result.append(y_pred, d[y_pred)]`, which was a
            # syntax error that used an undefined name and a tensor as a dict key.
            pred_idx = y_pred.argmax(dim=1).tolist()
            result_arr.append((y_pred.cpu(), [index_to_name[str(k)] for k in pred_idx]))

            if visualize:
                # reshape attention maps (same as deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    img_idx = i * config.batch_size + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % img_idx))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % img_idx))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % img_idx))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
Exemplo n.º 9
0
def main():
    """Evaluate a trained WS-DAN checkpoint on the RFW test images.

    Command-line options (optparse):
      --gpu                      GPU id (parsed, see NOTE below).
      --evalckpt/--eval-ckpt     checkpoint file to evaluate.
      -b/--batch-size            base batch size; the test loader uses 4x it.
      -j/--workers               number of DataLoader worker processes.
      --na/--num-attentions      number of attention maps (M) of WS-DAN.
      --cm/--confusion_matrix    whether to write a confusion matrix; a value
                                 given on the command line is interpreted as
                                 a boolean (e.g. True/False, 1/0, yes/no).

    For every batch the network is run on the raw images and on
    attention-guided crops of them; the two outputs are averaged to form the
    "refined" prediction. Raw and refined top-1 accuracies are printed at the
    end and, if enabled, a confusion matrix of the refined top-1 predictions
    is written to ``source/wsdan_confusion_matrix.svg``.

    Relies on names defined elsewhere in this module: ``device``,
    ``visualize``, ``savepath``, ``MEAN``, ``STD``, ``ToPILImage``,
    ``CustomDataset``, ``WSDAN``, ``batch_augment``, ``generate_heatmap``,
    ``TopKAccuracyMetric``, ``AverageMeter`` and ``draw_confusion_matrix``.
    """
    parser = OptionParser()

    # NOTE(review): options.GPU is never read below; the model is moved to
    # the module-level `device` instead.
    parser.add_option('--gpu',
                      dest='GPU',
                      default=0,
                      type='int',
                      help='GPU Id (default: 0)')
    parser.add_option('--evalckpt',
                      '--eval-ckpt',
                      dest='eval_ckpt',
                      default='models/wsdan/003.ckpt',
                      help='saved models are in ckpt directory')
    parser.add_option('-b',
                      '--batch-size',
                      dest='batch_size',
                      default=64,
                      type='int',
                      help='batch size (default: 64)')
    parser.add_option('-j',
                      '--workers',
                      dest='workers',
                      default=4,
                      type='int',
                      help='number of data loading workers (default: 4)')
    parser.add_option('--na',
                      '--num-attentions',
                      dest='num_attentions',
                      default=32,
                      type='int',
                      help='number of attentions')
    parser.add_option('--cm',
                      '--confusion_matrix',
                      dest='confusion_matrix',
                      default=True,
                      help='if you want to create confusion matrix')

    options, _ = parser.parse_args()

    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # optparse stores a command-line value for --cm as a string, so
    # `--cm False` would otherwise be truthy; normalise it to a bool.
    confusion_matrix = options.confusion_matrix
    if isinstance(confusion_matrix, str):
        confusion_matrix = confusion_matrix.strip().lower() not in (
            '', '0', 'false', 'no', 'off')

    ckpt = options.eval_ckpt
    if not ckpt:
        logging.info('Set ckpt for evaluation options')
        return

    # Dataset for testing
    transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    test_dataset = CustomDataset(
        data_root='/mnt/HDD/DatasetOriginals/RFW/test/data/',
        csv_file='data/RFW_Test_Images_Metadata.csv',
        transform=transform)

    # The test loader uses 4x the --batch-size value; keep it in one variable
    # so the visualisation file numbering below agrees with it.
    loader_batch_size = options.batch_size * 4
    test_loader = DataLoader(test_dataset,
                             batch_size=loader_batch_size,
                             shuffle=False,
                             num_workers=options.workers,
                             pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=4,
                M=options.num_attentions,
                net='inception_mixed_6e')

    # Load onto the CPU first so checkpoints saved on another GPU id still
    # load; the model is moved to `device` afterwards.
    checkpoint = torch.load(ckpt, map_location='cpu')
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 3))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 3))
    raw_accuracy.reset()
    ref_accuracy.reset()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top_refined = AverageMeter('Acc@1', ':6.2f')
    net.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            y_true.extend(y.tolist())
            X = X.to(device)
            y = y.to(device)

            # WS-DAN on the raw images
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1,
                                       padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_predicted = (y_pred_raw + y_pred_crop) / 2.
            _, pred = y_predicted.topk(1, 1, True, True)

            # topk(1, dim=1) yields shape (batch, 1); flatten so y_pred stays
            # 1-D like y_true when building the confusion matrix.
            y_pred.extend(pred.squeeze(1).cpu().tolist())

            if visualize:
                # Resize attention maps to the input resolution.
                # F.interpolate(..., align_corners=True) is the documented
                # replacement for the deprecated F.upsample_bilinear.
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2), X.size(3)),
                                               mode='bilinear',
                                               align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                # Number images by their position in the test set; the offset
                # must use the loader's real batch size to keep names unique.
                first_idx = i * loader_batch_size
                for batch_idx in range(X.size(0)):
                    sample_idx = first_idx + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(savepath, '%03d_raw.jpg' % sample_idx))
                    raimg.save(
                        os.path.join(savepath,
                                     '%03d_raw_atten.jpg' % sample_idx))
                    haimg.save(
                        os.path.join(savepath,
                                     '%03d_heat_atten.jpg' % sample_idx))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_predicted, y)
            # NOTE(review): if TopKAccuracyMetric returns a running
            # (cumulative) accuracy rather than a per-batch one, averaging it
            # again here over-weights early batches; confirm against its
            # implementation.
            top1.update(epoch_raw_acc[0], X.size(0))
            top_refined.update(epoch_ref_acc[0], X.size(0))

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
        print(' * Raw Accuracy {top1.avg:.3f}'.format(top1=top1))
        print(' * Refined Accuracy {top1.avg:.3f}'.format(top1=top_refined))
        print(len(y_pred), len(y_true))
        if confusion_matrix:
            file_name = 'source/wsdan_confusion_matrix.svg'
            # Make sure the output directory exists before writing.
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            draw_confusion_matrix(np.asarray(y_true), np.asarray(y_pred),
                                  file_name)