Code example #1
def deploy(args, data_loader):
    """Run inference over *data_loader*, saving per-sample localisation maps,
    input-gradient saliency maps, and a CSV of predictions to ``args.outpath``.

    Args:
        args: namespace carrying the network hyper-parameters (``network_k``,
            ``network_att_type``, ``kernel3``, ``network_width``,
            ``network_dropout``, ``norm``, ``input_channels``) plus
            ``logdir`` (checkpoint directory) and ``outpath`` (output
            directory).
        data_loader: iterable yielding ``(images, labels)`` batches.
            NOTE(review): outputs are saved as ``pred_{i}.npy`` per batch
            index — presumably batch size is 1; confirm against the caller.

    Raises:
        Exception: if ``logdir`` contains no ``best_checkpoint.pth``.
    """
    model = Network(k=args.network_k,
                    att_type=args.network_att_type,
                    kernel3=args.kernel3,
                    width=args.network_width,
                    dropout=args.network_dropout,
                    compensate=True,
                    norm=args.norm,
                    inp_channels=args.input_channels)

    print(model)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    checkpoint_path = os.path.join(args.logdir, 'best_checkpoint.pth')
    if os.path.isfile(checkpoint_path):
        # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        raise Exception("Couldn't load checkpoint.")

    # Inference mode: disable dropout and use running batch-norm statistics.
    model.eval()

    df = pd.DataFrame(columns=['img', 'label', 'pred'])

    with tqdm(enumerate(data_loader)) as pbar:
        for i, (images, labels) in pbar:
            raw_label = labels
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()

            # Input gradients are required for the saliency map, so
            # torch.no_grad() cannot be used here.
            images.requires_grad = True
            # Forward pass
            outputs, att, localised = model(images, True)
            localised = F.softmax(localised.data, 3)[..., 1]
            predicted = torch.argmax(outputs.data, 1)
            saliency = torch.autograd.grad(outputs[:, 1].sum(), images)[0].data

            localised = localised[0].cpu().numpy()
            # RMS across channels collapses the gradient to one saliency map.
            saliency = torch.sqrt((saliency[0]**2).mean(0)).cpu().numpy()
            np.save(os.path.join(args.outpath, 'pred_{}.npy'.format(i)),
                    localised)

            np.save(os.path.join(args.outpath, 'sal_{}.npy'.format(i)),
                    saliency)

            df.loc[len(df)] = [
                i,
                raw_label.numpy().squeeze(),
                predicted.cpu().numpy().squeeze()
            ]

    df.to_csv(os.path.join(args.outpath, 'pred.csv'), index=False)
    print('done - stopping now')
Code example #2
File: train.py  Project: suyanzhou626/FC_densenet
def main(params):
    """Train the FC-DenseNet segmentation network on the pickled dataset.

    Args:
        params: dict with keys ``train_data_pkl``, ``train_anno_pkl``
            (pickled arrays on disk), ``batch_size``, ``lr``, ``l2_reg``,
            ``max_epoch``, ``num_answers`` (classes per pixel),
            ``model_from`` (optional warm-start checkpoint path) and
            ``model_save`` (checkpoint output path).

    Side effects:
        Prints progress, and saves the model state dict to
        ``params['model_save']`` every 10 epochs.
    """
    print("Loading dataset ... ")

    with open(params['train_data_pkl'], 'rb') as f:
        train_data = pkl.load(f)
    with open(params['train_anno_pkl'], 'rb') as f:
        train_anno = pkl.load(f)

    # Train dataset and Train dataloader
    # NHWC -> NCHW, the layout PyTorch convolutions expect.
    train_data = np.transpose(train_data, (0, 3, 1, 2))
    train_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(train_data), torch.LongTensor(train_anno))

    train_loader = dataloader.DataLoader(train_dataset,
                                         params['batch_size'],
                                         shuffle=True,
                                         collate_fn=collate_fn)

    # the number of layers in each dense block
    n_layers_list = [4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4]

    print("Constructing the network ... ")
    # Define the network
    densenet = Network(n_layers_list, 5).to(device)

    if os.path.isfile(params['model_from']):
        print("Starting from the saved model")
        densenet.load_state_dict(torch.load(params['model_from']))
    else:
        print("Couldn't find the saved model")
        print("Starting from the bottom")

    print("Training the model ...")
    # hyperparameter, optimizer, criterion
    optimizer = torch.optim.RMSprop(densenet.parameters(),
                                    params['lr'],
                                    weight_decay=params['l2_reg'])
    # Decay the lr by 0.995 per epoch via a scheduler instead of rebuilding
    # the optimizer, which would discard RMSprop's running squared-gradient
    # state at every epoch boundary.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(params['max_epoch']):
        for i, (img, label) in enumerate(train_loader):
            img = img.to(device)
            label = label.to(device)

            # forward-propagation
            pred = densenet(img)

            # flatten for all pixel: CrossEntropyLoss scores each pixel
            # independently as one (num_answers)-way classification.
            pred = pred.view((-1, params['num_answers']))
            label = label.view((-1))

            # get loss
            loss = criterion(pred, label)

            # back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() extracts the Python float (loss.data is deprecated).
            print("Epoch: %d, Steps:[%d/%d], Loss: %.4f" %
                  (epoch, i, len(train_loader), loss.item()))

        scheduler.step()

        if (epoch + 1) % 10 == 0:
            print("Saved the model")
            torch.save(densenet.state_dict(), params['model_save'])