Exemplo n.º 1
0
                        default='cuda',
                        help="Device: 'cuda' or 'cpu'")
    args = parser.parse_args()

    # Use the GPU only when it was requested AND is actually available;
    # otherwise silently fall back to the CPU.
    want_cuda = torch.cuda.is_available() and args.device == 'cuda'
    device = torch.device("cuda" if want_cuda else "cpu")

    # Label vocabulary for every category, plus the string-name <-> ID mappings.
    attributes = AttributesDataset(args.attributes_file)

    # Validation-time preprocessing: tensor conversion and normalization only
    # (no augmentation).
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Fixed-order loader over the validation CSV.
    test_dataset = FashionDataset('./val.csv', attributes, val_transform)
    test_dataloader = DataLoader(
        test_dataset, batch_size=64, shuffle=False, num_workers=8)

    # One output head per attribute, sized from the label vocabulary.
    model = MultiOutputModel(
        n_color_classes=attributes.num_colors,
        n_gender_classes=attributes.num_genders,
        n_article_classes=attributes.num_articles,
    ).to(device)

    # Visualize the trained model. The weights are not loaded in this block;
    # visualize_grid receives the checkpoint path and presumably restores it
    # itself (NOTE(review): confirm).
    visualize_grid(model, test_dataloader, attributes, device,
                   checkpoint=args.checkpoint)
Exemplo n.º 2
0
    # Training split: reshuffled every epoch.
    train_dataset = FashionDataset('./fashion-product-images/train.csv',
                                   attributes, train_transform)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)

    # Validation split: iterated in a fixed order (no shuffling).
    val_dataset = FashionDataset('./fashion-product-images/val.csv',
                                 attributes, val_transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

    # One output head per attribute, sized from the label vocabulary.
    model = MultiOutputModel(n_color_classes=attributes.num_colors,
                             n_gender_classes=attributes.num_genders,
                             n_article_classes=attributes.num_articles
                             ).to(device)

    optimizer = torch.optim.Adam(model.parameters())

    # Take the run timestamp ONCE, so the log and checkpoint directories of a
    # run always share the same name. Two separate get_cur_time() calls (which
    # presumably read the wall clock) could straddle a clock tick and diverge.
    run_stamp = get_cur_time()
    logdir = os.path.join('./logs/', run_stamp)
    savedir = os.path.join('./checkpoints/', run_stamp)
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(savedir, exist_ok=True)
    logger = SummaryWriter(logdir)

    # NOTE: len() of a DataLoader is the number of batches per epoch, not the
    # number of samples, despite this variable's name. The name is kept because
    # later code may read it.
    n_train_samples = len(train_dataloader)

    # Visualize example val-set images with their ground-truth labels. The
    # call below is active; comment it out to skip this step.
    visualize_grid(model,
                   val_dataloader,
Exemplo n.º 3
0
# Print the class labels of the two highest-scoring predictions.
# NOTE(review): `model`, `image` and `mlb` are not defined anywhere in the
# visible code. `model.predict` / `mlb.classes_` look like a Keras-style model
# and a fitted scikit-learn MultiLabelBinarizer, apparently pasted from a
# different script — confirm. As top-level statements they run on import,
# before the `if __name__ == '__main__':` guard.
# Scores for the first (presumably only) item of the prediction batch.
proba = model.predict(image)[0]
# Indices of the two largest scores, largest first (np.argsort sorts ascending,
# so the result is reversed before taking the first two).
idxs = np.argsort(proba)[::-1][:2]

# Print the class label for each selected index (the enumerate index `i` is unused).
for (i, j) in enumerate(idxs):
    label = "{}".format(mlb.classes_[j])
    print(label)

if __name__ == '__main__':
    # Hard-coded inference settings.
    # NOTE(review): absolute paths at the filesystem root look like
    # placeholders — confirm before running.
    checkpoint = '/checkpoint-000050.pth'
    attributes_file = '/styles.csv'
    device = 'cpu'

    # The label vocabulary supplies the number of classes for each output head.
    attributes = AttributesDataset(attributes_file)
    model = MultiOutputModel(
        n_color_classes=attributes.num_colors,
        n_gender_classes=attributes.num_genders,
        n_article_classes=attributes.num_articles).to(device)

    # Path to the input image. DATASET_PATH is not defined in this block;
    # presumably it is set elsewhere in the file (NOTE(review): confirm).
    res = DATASET_PATH

    # Same preprocessing as validation: tensor conversion + normalization.
    val_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    # Close the image file once it has been read. Force 3 channels so the
    # 3-channel mean/std normalization also works for grayscale/RGBA files.
    with Image.open(res) as pil_img:
        img = val_transform(pil_img.convert('RGB'))
    # ToTensor yields a (C, H, W) tensor; add a leading batch dimension to get
    # (1, C, H, W). Reshaping with img.shape[0] / img.shape[1] as the spatial
    # sizes would be wrong: shape[0] is the channel count, which scrambles the
    # pixels.
    img = img.unsqueeze(0)

    name = checkpoint
    print('Restoring checkpoint: {}'.format(name))
    # map_location lets GPU-saved tensors load onto `device`.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(name, map_location=device))
    model.eval()