Example 1
def train(epoch):
    # lr_scheduler.step()
    model.train()
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        
        # skip batches whose labels fall outside the valid range [0, 20)
        if (target.numpy() >= 20).any() or (target.numpy() < 0).any():
            print(target.numpy())
            continue
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        y_pred_raw, feature_matrix, attention_map = model(data)
        # Update Feature Center
        feature_center_batch = F.normalize(feature_center[target], dim=-1)
        feature_center[target] += BETA * (feature_matrix.detach() - feature_center_batch)
      
        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(data, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = model(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(data, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = model(drop_images)

        # loss
        batch_loss = cross_entropy_loss(y_pred_raw, target) / 3. + \
                     cross_entropy_loss(y_pred_crop, target) / 3. + \
                     cross_entropy_loss(y_pred_drop, target) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()
      
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss.item()))
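
None of these snippets define batch_augment itself. A minimal sketch of its 'crop' mode, inferred from how it is called above (the function name and details here are assumptions; in particular, when theta is a (low, high) tuple the real helper samples a threshold per image, while this sketch takes a fixed scalar):

import torch
import torch.nn.functional as F

def batch_augment_crop(images, attention_map, theta=0.5, padding_ratio=0.1):
    # Crop each image to the box where its attention exceeds
    # theta * max(attention), pad the box, and resize back to input size.
    batch, _, h, w = images.shape
    crops = []
    for k in range(batch):
        attn = F.interpolate(attention_map[k:k + 1], size=(h, w),
                             mode='bilinear', align_corners=False)
        mask = attn[0, 0] >= attn.max() * theta
        ys, xs = mask.any(dim=1).nonzero(), mask.any(dim=0).nonzero()
        pad_h, pad_w = int(h * padding_ratio), int(w * padding_ratio)
        y0 = max(ys.min().item() - pad_h, 0)
        y1 = min(ys.max().item() + pad_h + 1, h)
        x0 = max(xs.min().item() - pad_w, 0)
        x1 = min(xs.max().item() + pad_w + 1, w)
        crops.append(F.interpolate(images[k:k + 1, :, y0:y1, x0:x1], size=(h, w),
                                   mode='bilinear', align_corners=False))
    return torch.cat(crops, dim=0)

The 'drop' mode used above is the complement: pixels whose attention exceeds the threshold are zeroed out, forcing the network to attend elsewhere.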
Example 2
def validate(**kwargs):
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    pbar = kwargs['pbar']

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()

    # begin validation
    start_time = time.time()
    net.eval()
    with torch.no_grad():
        for i, (X, y, id) in enumerate(data_loader):
            # obtain data
            X = X.to(device)
            y = y.to(device)

            ##################################
            # Raw Image
            ##################################
            y_pred_raw, _, attention_map = net(X)

            ##################################
            # Object Localization and Refinement
            ##################################
            crop_images = batch_augment(X,
                                        attention_map,
                                        mode='crop',
                                        theta=0.1,
                                        padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_images)

            ##################################
            # Final prediction
            ##################################
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # loss
            batch_loss = cross_entropy_loss(y_pred, y)
            epoch_loss = loss_container(batch_loss.item())

            # metrics: top-1,5 error
            epoch_acc = raw_metric(y_pred, y)

    # end of validation
    logs['val_{}'.format(loss_container.name)] = epoch_loss
    logs['val_{}'.format(raw_metric.name)] = epoch_acc
    end_time = time.time()

    batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(
        epoch_loss, epoch_acc[0], epoch_acc[1])
    pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))

    # write log for this epoch
    logging.info('Valid: {}, Time {:3.2f}'.format(batch_info,
                                                  end_time - start_time))
    logging.info('')
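
Example 2 calls loss_container and raw_metric without defining them. A plausible minimal loss_container is a running-average meter matching the reset()/call protocol used above (the class name and details are assumptions):

class AverageMeter:
    # Running average of a scalar; calling the meter updates it
    # and returns the current mean.
    def __init__(self, name='loss'):
        self.name = name
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def __call__(self, value, n=1):
        self.total += value * n
        self.count += n
        return self.total / self.count

loss_container = AverageMeter(name='loss')

A sketch of the top-k accuracy metrics (raw_metric and friends) appears after Example 9.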
Example 3
def write_csv(model, te_dataset, submission_df_path, options=None):
    print("Generating prediction...")
    device = get_device()
    te_dataloader = DataLoader(te_dataset,
                               batch_size=batch_size,
                               shuffle=False)
    submission_df = pd.read_csv(submission_df_path)

    test_pred = None
    model.eval()
    with torch.no_grad():
        for inputs in te_dataloader:
            inputs = inputs.to(device)

            if options is not None and options.model == 4:

                y_pred_raw, _, attention_map = model(inputs)

                crop_images = batch_augment(inputs,
                                            attention_map,
                                            mode='crop',
                                            theta=0.1,
                                            padding_ratio=0.05)
                y_pred_crop, _, _ = model(crop_images)

                y_pred = (y_pred_raw + y_pred_crop) / 2.
                outputs = y_pred
            else:
                outputs = model(inputs)

            if test_pred is None:
                test_pred = outputs.data.cpu()
            else:
                test_pred = torch.cat((test_pred, outputs.data.cpu()), dim=0)

    test_pred = torch.softmax(test_pred, dim=1)
    submission_df[['healthy', 'multiple_diseases', 'rust', 'scab']] = test_pred.numpy()

    submission_df.to_csv(options.output_root + options.output_name + '.csv',
                         index=False)
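
write_csv relies on a get_device helper (also used in Examples 4 and 11) that is not shown. A minimal stand-in consistent with its usage:

import torch

def get_device():
    # Prefer the GPU when one is available; fall back to the CPU.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')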
Example 4
def validation(model, val_dataloader, criterion, epoch, options=None):
    device = get_device()
    model.to(device)
    model.eval()

    running_loss = 0.
    running_corrects = 0.
    with torch.no_grad():
        for inputs, labels, _ in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            #labels = labels.squeeze(-1)

            y_pred_raw, _, attention_map = model(inputs)

            crop_images = batch_augment(inputs,
                                        attention_map,
                                        mode='crop',
                                        theta=0.1,
                                        padding_ratio=0.05)
            y_pred_crop, _, _ = model(crop_images)

            y_pred = (y_pred_raw + y_pred_crop) / 2.
            outputs = y_pred
            if len(labels.shape) == 0:
                print("Error")
                loss = torch.tensor(0.)
            else:
                # note: uses the module-level cross_entropy_loss, not the criterion argument
                loss = cross_entropy_loss(y_pred, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(labels == outputs.argmax(dim=1))

    epoch_loss = running_loss / len(val_dataloader.dataset)
    epoch_acc = running_corrects.double() / len(val_dataloader.dataset)

    print('[Validation]Epoch: {}, Loss: {:.4f} Acc: {:.4f}'.format(
        epoch, epoch_loss, epoch_acc))
    return epoch_acc, epoch_loss
Example 5
@torch.no_grad()  # evaluation only; avoid building autograd graphs
def evaluate(test_loader, epoch, model):
    model.eval()
    test_loss, correct, total, tp, fp, tn, fn = 0, 0, 0, 0, 0, 0, 0
    criterion = torch.nn.CrossEntropyLoss()

    # print(len(test_loader))
    for data in tqdm(test_loader):
        try:
            image = data["image"].cuda()
            label = data["label"].cuda()
        except OSError:
            # print("OSError of image. ")
            continue

        y_pred_raw, _, attention_map = model(image)
        crop_images = batch_augment(image,
                                    attention_map,
                                    mode='crop',
                                    theta=0.1,
                                    padding_ratio=0.05)
        y_pred_crop, _, _ = model(crop_images)
        y_pred = (y_pred_raw + y_pred_crop) / 2.

        loss = criterion(y_pred, label)
        test_loss += loss.item()

        _, predict = y_pred.max(1)
        total += label.size(0)
        correct += predict.eq(label).sum().item()
        tp += torch.sum(predict & label)
        fp += torch.sum(predict & (1 - label))
        tn += torch.sum((1 - predict) & (1 - label))
        fn += torch.sum((1 - predict) & label)
    acc = 100. * correct / total
    precision = 100.0 * tp / (tp + fp).float()
    recall = 100.0 * tp / (tp + fn).float()
    print(
        "==> [evaluate] epoch {}, loss = {}, acc = {}, precision = {}, recall = {}"
        .format(epoch, test_loss, acc, precision, recall))
    return acc, precision, recall
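
The bitwise tallies above (predict & label and so on) are only meaningful for binary 0/1 integer labels and predictions. An equivalent, more explicit formulation of the same counts:

import torch

def binary_confusion_counts(predict, label):
    # tp, fp, tn, fn for 0/1 tensors; mirrors the bitwise tallies above.
    predict, label = predict.bool(), label.bool()
    tp = (predict & label).sum().item()
    fp = (predict & ~label).sum().item()
    tn = (~predict & ~label).sum().item()
    fn = (~predict & label).sum().item()
    return tp, fp, tn, fn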
Example 6
@torch.no_grad()  # inference only
def validation(val_loader):
    model.eval()
    validation_loss = 0.
    correct = 0
    for data, target in val_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()

        ##################################
        # Raw Image
        ##################################
        y_pred_raw, _, attention_map = model(data)

        ##################################
        # Object Localization and Refinement
        ##################################
        crop_images = batch_augment(data, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
        y_pred_crop, _, _ = model(crop_images)

        ##################################
        # Final prediction
        ##################################
        y_pred = (y_pred_raw + y_pred_crop) / 2.

        # loss, accumulated per sample so we can report the set-wide average
        batch_loss = cross_entropy_loss(y_pred, target)
        validation_loss += batch_loss.item() * data.size(0)

        # get the index of the max log-probability
        pred = y_pred.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    validation_loss /= len(val_loader.dataset)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        validation_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
Example 7
def train(**kwargs):
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    feature_center = kwargs['feature_center']
    optimizer = kwargs['optimizer']
    pbar = kwargs['pbar']

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()
    crop_metric.reset()
    drop_metric.reset()

    # begin training
    start_time = time.time()
    net.train()
    for i, (X, y) in enumerate(data_loader):
        optimizer.zero_grad()

        # obtain data for training
        X = X.to(device)
        y = y.to(device)

        ##################################
        # Raw Image
        ##################################
        # raw images forward
        y_pred_raw, feature_matrix, attention_map = net(X)

        # Update Feature Center
        feature_center_batch = F.normalize(feature_center[y], dim=-1)
        feature_center[y] += 5e-2 * (feature_matrix.detach() - feature_center_batch)

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = net(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = net(drop_images)

        # loss
        batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
                     cross_entropy_loss(y_pred_crop, y) / 3. + \
                     cross_entropy_loss(y_pred_drop, y) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        # metrics: loss and top-1,5 error
        with torch.no_grad():
            epoch_loss = loss_container(batch_loss.item())
            epoch_raw_acc = raw_metric(y_pred_raw, y)
            epoch_crop_acc = crop_metric(y_pred_crop, y)
            epoch_drop_acc = drop_metric(y_pred_drop, y)

        # end of this batch
        batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
            epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
            epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1])
        pbar.update()
        pbar.set_postfix_str(batch_info)

    # end of this epoch
    logs['train_{}'.format(loss_container.name)] = epoch_loss
    logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
    logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
    logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
    logs['train_info'] = batch_info
    end_time = time.time()

    # write log for this epoch
    logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
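
The center_loss term in Examples 1, 7 and 11 penalizes the distance between the pooled feature matrix and the per-class centers that the loop maintains as an exponential moving average. A minimal sketch consistent with how it is called (the mean-squared-error form is an assumption):

import torch.nn as nn
import torch.nn.functional as F

class CenterLoss(nn.Module):
    # Mean squared distance between features and their class centers.
    def forward(self, features, centers):
        return F.mse_loss(features, centers)

center_loss = CenterLoss()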
Example 8
@torch.no_grad()  # inference only
def predict(image_path,
            model_param_path,
            save_path,
            img_save_name,
            resize=(224, 224),
            gen_hm=False):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        # transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
        transforms.Resize(size=(int(resize[0]), int(resize[1]))),
        # transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)
    image = image.unsqueeze(0)

    net = WSDAN(num_classes=4)
    net.load_state_dict(torch.load(model_param_path))
    net.eval()

    # choose the device by actual availability instead of guessing from the filename
    if 'gpu' in model_param_path and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    net.to(device)

    X = image
    X = X.to(device)

    # WS-DAN
    y_pred_raw, _, attention_maps = net(X)
    attention_maps = torch.mean(attention_maps, dim=1, keepdim=True)

    # Augmentation with crop_mask
    crop_image = batch_augment(X,
                               attention_maps,
                               mode='crop',
                               theta=0.1,
                               padding_ratio=0.05)

    y_pred_crop, _, _ = net(crop_image)
    y_pred = (y_pred_raw + y_pred_crop) / 2.
    y_pred = F.softmax(y_pred, dim=1)

    if gen_hm:

        attention_maps = F.interpolate(attention_maps,
                                       size=(X.size(2), X.size(3)),
                                       mode='bilinear', align_corners=True)
        attention_maps = torch.sqrt(attention_maps.cpu() /
                                    attention_maps.max().item())

        # get heat attention maps
        heat_attention_maps = generate_heatmap(attention_maps)

        # raw_image, heat_attention, raw_attention
        raw_image = X.cpu() * STD + MEAN
        heat_attention_image = raw_image * 0.4 + heat_attention_maps * 0.6
        raw_attention_image = raw_image * attention_maps

        to_pil = ToPILImage()  # ToPILImage is a transform class; instantiate, then apply
        for batch_idx in range(X.size(0)):
            rimg = to_pil(raw_image[batch_idx])
            raimg = to_pil(raw_attention_image[batch_idx])
            haimg = to_pil(heat_attention_image[batch_idx])
            rimg.save(
                os.path.join(save_path, '{}_raw.jpg'.format(img_save_name)))
            raimg.save(
                os.path.join(save_path,
                             '{}_raw_atten.jpg'.format(img_save_name)))
            haimg.save(
                os.path.join(save_path,
                             '{}_heat_atten.jpg'.format(img_save_name)))

    df = pd.read_csv("../data/train.csv")
    label = None  # stays None if the image id is not found in train.csv
    for i in range(len(df)):
        # if df.loc[i, 'image_id'] in image_path:
        head, tail = os.path.split(image_path)
        if df.loc[i, 'image_id'] == tail[:-4]:
            label = torch.tensor(
                df.loc[i, ['healthy', 'multiple_diseases', 'rust', 'scab']])
            break
    return y_pred, label
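
generate_heatmap (used here and in Examples 9, 10 and 12) is not defined in these snippets. A crude stand-in that maps (B, 1, H, W) attention values in [0, 1] to an RGB heatmap; the exact color ramp is an assumption:

import torch

def generate_heatmap(attention_maps):
    # Red rises with attention, blue fades, green peaks mid-range.
    a = attention_maps.clamp(0, 1)
    red = a
    green = 1.0 - (2.0 * a - 1.0).abs()
    blue = 1.0 - a
    return torch.cat([red, green, blue], dim=1)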
Example 9
def main():
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    # _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=2,
                             pin_memory=True)
    name2label, label2name = mapping('../training_labels.csv')
    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=config.num_attentions,
                net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y, id) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(id)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1,
                                       padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Save the predictions (the csv is rewritten every batch so partial results survive interruption)
            logits.append(y_pred.cpu())
            prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)

            submission = pd.DataFrame(
                [ids,
                 [label2name[x] for x in prediction.numpy()]]).transpose()
            submission.columns = ['id', 'label']
            submission.to_csv(savepath + 'predictions.csv', index=False)

            if visualize:
                # reshape attention maps
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2),
                                                     X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                to_pil = ToPILImage()  # instantiate the transform before applying it
                for batch_idx in range(X.size(0)):
                    rimg = to_pil(raw_image[batch_idx])
                    raimg = to_pil(raw_attention_image[batch_idx])
                    haimg = to_pil(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(
                            savepath, '%03d_raw.jpg' %
                            (i * config.batch_size + batch_idx)))
                    raimg.save(
                        os.path.join(
                            savepath, '%03d_raw_atten.jpg' %
                            (i * config.batch_size + batch_idx)))
                    haimg.save(
                        os.path.join(
                            savepath, '%03d_heat_atten.jpg' %
                            (i * config.batch_size + batch_idx)))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
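
TopKAccuracyMetric keeps running top-k accuracies across batches. The reset()/call protocol used above suggests something like this sketch (details are assumptions):

import torch

class TopKAccuracyMetric:
    name = 'accuracy'

    def __init__(self, topk=(1, 5)):
        self.topk = topk
        self.reset()

    def reset(self):
        self.corrects = [0] * len(self.topk)
        self.total = 0

    def __call__(self, logits, targets):
        _, pred = logits.topk(max(self.topk), dim=1)  # (B, maxk) class indices
        hits = pred.eq(targets.view(-1, 1))           # (B, maxk) boolean hits
        self.total += targets.size(0)
        accs = []
        for j, k in enumerate(self.topk):
            self.corrects[j] += hits[:, :k].any(dim=1).sum().item()
            accs.append(100.0 * self.corrects[j] / self.total)
        return accs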
Example 10
def main():
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = sys.argv[1]
    except IndexError:
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=32,
                net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, feature_matrix, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1)

            y_pred_crop, _, _ = net(crop_image)
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # reshape attention maps
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2),
                                                     X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                to_pil = ToPILImage()  # instantiate the transform before applying it
                for batch_idx in range(X.size(0)):
                    rimg = to_pil(raw_image[batch_idx])
                    raimg = to_pil(raw_attention_image[batch_idx])
                    haimg = to_pil(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(savepath,
                                     '%03d_raw.jpg' % (i + batch_idx)))
                    raimg.save(
                        os.path.join(savepath,
                                     '%03d_raw_atten.jpg' % (i + batch_idx)))
                    haimg.save(
                        os.path.join(savepath,
                                     '%03d_heat_atten.jpg' % (i + batch_idx)))

            # Top K
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(
                epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for this epoch
    logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
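
The visualization code de-normalizes with X.cpu() * STD + MEAN. For that broadcast to work over (B, 3, H, W) tensors, the STD and MEAN globals presumably hold the ImageNet statistics used in Example 8's Normalize transform, shaped like this:

import torch

MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)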
Example 11
def train(model,
          tr_dataloader,
          criterion,
          optimizer,
          epoch,
          options=None,
          feature_center=None):
    since = time.time()
    device = get_device()
    model.train()

    model.to(device)

    running_loss = 0.
    running_corrects = 0.
    for idx, (inputs, labels, _) in enumerate(tr_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        #labels = labels.squeeze(-1)
        optimizer.zero_grad()

        y_pred_raw, feature_matrix, attention_map = model(inputs)

        feature_center_batch = F.normalize(feature_center[labels], dim=-1)
        feature_center[labels] += 0.05 * (feature_matrix.detach() -
                                          feature_center_batch)

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(inputs,
                                        attention_map[:, :1, :, :],
                                        mode='crop',
                                        theta=(0.4, 0.6),
                                        padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = model(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(inputs,
                                        attention_map[:, 1:, :, :],
                                        mode='drop',
                                        theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = model(drop_images)
        outputs = (y_pred_raw + y_pred_crop + y_pred_drop) / 3.
        # loss
        loss = cross_entropy_loss(y_pred_raw, labels) / 3. + \
                    cross_entropy_loss(y_pred_crop, labels) / 3. + \
                    cross_entropy_loss(y_pred_drop, labels) / 3. + \
                    center_loss(feature_matrix, feature_center_batch)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item() * inputs.size(0)
        batch_corrects = torch.sum(labels == outputs.argmax(dim=1))

        running_loss += batch_loss
        running_corrects += batch_corrects

        if idx % 9 == 1:
            print('[Train]Epoch: {}, idx: {}, Loss: {:.4f} Acc: {:.4f}'.format(
                epoch, idx, batch_loss / len(inputs),
                batch_corrects.float() / len(inputs)))

    epoch_loss = running_loss / len(tr_dataloader.dataset)
    epoch_acc = running_corrects.double() / len(tr_dataloader.dataset)

    print('[Train]Epoch: {}, Loss: {:.4f} Acc: {:.4f}'.format(
        epoch, epoch_loss, epoch_acc))

    return epoch_acc, epoch_loss
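
The update feature_center[labels] += 0.05 * (feature_matrix.detach() - feature_center_batch) is an exponential moving average of one feature center per class. For the indexing to work, the buffer must be pre-allocated with one row per class, along these lines (the concrete sizes are assumptions; Example 13 builds it from its config):

import torch

num_classes, num_attentions, num_features = 4, 32, 768  # assumed sizes
feature_center = torch.zeros(num_classes, num_attentions * num_features)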
Example 12
def main(result_arr):
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.
            
            # build the index -> name lookup (re-read each batch here; it could be hoisted above the loop)
            d = {}
            reader = csv.reader(open('/home/naman/Documents/Assignment_Job/out_dict.csv', 'r'))
            for row in reader:
                k, v = row
                d[v] = k

            # record each predicted class index together with its mapped name
            for ix in y_pred.argmax(dim=1).cpu().numpy():
                result_arr.append((int(ix), d.get(str(ix))))
            
            if visualize:
                # reshape attention maps
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)), mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                to_pil = ToPILImage()  # instantiate the transform before applying it
                for batch_idx in range(X.size(0)):
                    rimg = to_pil(raw_image[batch_idx])
                    raimg = to_pil(raw_attention_image[batch_idx])
                    haimg = to_pil(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % (i * config.batch_size + batch_idx)))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % (i * config.batch_size + batch_idx)))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % (i * config.batch_size + batch_idx)))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
Example 13
def model_train(model, train_loader, eval_loader, test_loader, cfg):
    model.train()
    writer = SummaryWriter(log_dir=cfg.log_path)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.0001,
                                 weight_decay=0.0005,
                                 amsgrad=True)
    current_epoch = 0
    global_step = 0
    if cfg.load_ckp:
        model, optimizer, current_epoch, global_step, loss = load_model(
            model, optimizer, cfg.ckp_path)

    # prepared for center loss, though the update below is currently commented out
    feature_center = torch.zeros(2, cfg.num_attentions * model.num_features)
    center_loss = CenterLoss()

    for epoch in range(current_epoch, cfg.NUM_EPOCHS, 1):
        running_loss = 0.0
        _time = time()
        for i, data in enumerate(tqdm(train_loader)):
            if i == len(train_loader) - 1:  # skip the final, possibly smaller batch
                break
            try:
                image = data["image"].cuda()
                label = data["label"].cuda()
            except OSError:
                print("OSError of image. ")
                continue
            optimizer.zero_grad()
            y_pred_raw, feature_matrix, attention_map = model(image)
            '''
            # Update Feature Center            
            feature_center_batch = torch.nn.functional.normalize(feature_center[label], dim=-1)
            print(feature_center[label].shape, feature_matrix.detach().shape, feature_center_batch.shape)
            feature_center_batch[label] += cfg.beta * (feature_matrix.detach() - feature_center_batch)
            '''

            # Attention Cropping
            with torch.no_grad():
                crop_images = batch_augment(image,
                                            attention_map[:, :1, :, :],
                                            mode='crop',
                                            theta=(0.4, 0.6),
                                            padding_ratio=0.1)

            # crop images forward
            y_pred_crop, _, _ = model(crop_images)
            '''
            # Attention Dropping
            with torch.no_grad():
                drop_images = batch_augment(image, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

            # drop images forward
            y_pred_drop, _, _ = model(drop_images)
            '''

            loss = criterion(y_pred_raw, label) / 3. + \
                         criterion(y_pred_crop, label) / 3
            #criterion(y_pred_drop, label) / 3. + \
            #0  #center_loss(feature_matrix, feature_center_batch)

            #print(loss)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 200 == 0:
                batch_time = time() - _time
                print(
                    "==> [train] epoch {}, batch {}, global_step {}. loss for last 200 batches: {}, time for last 200 batches: {}s"
                    .format(epoch, i, global_step, running_loss, batch_time))
                writer.add_scalar("scalar/loss", running_loss, global_step,
                                  time())
                running_loss = 0.0
                _time = time()
            global_step += 1
        # TODO add save condition eg. acc

        if epoch % cfg.evaluate_epoch == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "global_step": global_step,
                    'loss': loss,
                },
                os.path.join(cfg.save_path,
                             "train_epoch_" + str(epoch) + ".tar"))
            print("==> [eval] on train dataset")
            acc_on_train, precision_on_train, recall_on_train = evaluate(
                train_loader, epoch, model)
            print("==> [eval] on valid dataset")
            acc_on_valid, precision_on_valid, recall_on_valid = evaluate(
                eval_loader, epoch, model)
            writer.add_scalar("scalar/accuracy_on_train", acc_on_train,
                              global_step, time())
            writer.add_scalar("scalar/accuracy_on_valid", acc_on_valid,
                              global_step, time())
            writer.add_scalar("scalar/precisoin_on_train", precision_on_train,
                              global_step, time())
            writer.add_scalar("scalar/precision_on_valid", precision_on_valid,
                              global_step, time())
            writer.add_scalar("scalar/recall_on_train", recall_on_train,
                              global_step, time())
            writer.add_scalar("scalar/recall_on_valid", recall_on_valid,
                              global_step, time())

    writer.close()
    print("Finish training.")