Example #1
def main():
    args.task_selection = args.task_selection.split(',')

    torch.manual_seed(args.seed)

    # LOAD DATASET
    stat_file = args.stat_file
    with open(stat_file, 'r') as f:
        data = pickle.load(f)
        mean, std = data['mean'], data['std']
        mean = [float(m) for m in mean]
        std = [float(s) for s in std]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
    ])

    if not args.shape_dataset:
        if args.task_selection is not None:
            classes = args.task_selection
        elif args.office_dataset:
            classes = ['style', 'genre']
        elif args.bam_dataset:
            classes = ['content', 'emotion', 'media']
        else:
            classes = ['artist_name', 'genre', 'style', 'technique', 'century']
        valset = Wikiart(path_to_info_file=args.val_file, path_to_images=args.im_path,
                         classes=classes, transform=val_transform)
        trainset = Wikiart(path_to_info_file=args.train_file, path_to_images=args.im_path,
                           classes=classes, transform=train_transform)
    else:
        if args.task_selection is not None:
            classes = args.task_selection
        else:
            classes = ['shape', 'n_shapes', 'color_shape', 'color_background']
        valset = ShapeDataset(root_dir='/export/home/kschwarz/Documents/Data/Geometric_Shapes', split='val',
                              classes=classes, transform=val_transform)
        trainset = ShapeDataset(root_dir='/export/home/kschwarz/Documents/Data/Geometric_Shapes', split='train',
                                classes=classes, transform=train_transform)

    if not trainset.labels_to_ints == valset.labels_to_ints:
        print('Validation and training set label-to-int mappings do not match. Using the trainset mapping.')
        print(trainset.labels_to_ints, valset.labels_to_ints)
        valset.labels_to_ints = trainset.labels_to_ints.copy()

    num_labels = [len(trainset.labels_to_ints[c]) for c in classes]

    # PARAMETERS
    use_cuda = args.use_gpu and torch.cuda.is_available()
    device_nb = args.device
    if use_cuda:
        torch.cuda.set_device(device_nb)
        torch.cuda.manual_seed_all(args.seed)

    # INITIALIZE NETWORK
    if args.model.lower() not in ['mobilenet_v2', 'vgg16_bn']:
        raise NotImplementedError('Unknown Model {}\n\t+ Choose from: [mobilenet_v2, vgg16_bn].'
                                  .format(args.model))
    elif args.model.lower() == 'mobilenet_v2':
        featurenet = mobilenet_v2(pretrained=True)
    elif args.model.lower() == 'vgg16_bn':
        featurenet = vgg16_bn(pretrained=True)
    if args.not_narrow:
        bodynet = featurenet
    else:
        bodynet = narrownet(featurenet, dim_feature_out=args.feature_dim)
    net = OctopusNet(bodynet, n_labels=num_labels)
    n_parameters = sum([p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(str(bodynet).split('(')[0], n_parameters))

    # LOG/SAVE OPTIONS
    log_interval = args.log_interval
    log_dir = args.log_dir
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    if args.shape_dataset:
        expname = timestamp + '_ShapeDataset_' + str(bodynet).split('(')[0]
    else:
        expname = timestamp + '_' + str(bodynet).split('(')[0]
    if args.exp_name is not None:
        expname = expname + '_' + args.exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    write_config(args, os.path.join(log_dir, expname))

    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, threshold=1e-1, verbose=True)
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}
    trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
    valloader = DataLoader(valset, batch_size=args.batch_size, shuffle=True, **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    if args.chkpt is not None:
        if os.path.isfile(args.chkpt):
            print("=> loading checkpoint '{}'".format(args.chkpt))
            checkpoint = torch.load(args.chkpt, map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            best_acc_score = checkpoint['best_acc_score']
            best_acc = checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.chkpt))

    def train(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0])).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # in case of None labels
                mask = t != -1
                if mask.sum() == 0:
                    continue
                o, t, p = o[mask], t[mask], p[mask]
                loss += criterion(o, t)
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p == t).type(torch.FloatTensor) / t.size(0)).data)
            accs.update(torch.mean(torch.stack([class_acc[i].val for i in range(len(classes))])), target[0].size(0))
            losses.update(loss.data, target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'Acc: {:.2f}% ({:.2f}%)'.format(
                    epoch, batch_idx * len(data), len(trainloader.dataset),
                    float(losses.val), float(losses.avg),
                           float(accs.val) * 100., float(accs.avg) * 100.))
                print('\t' + '\n\t'.join(['{}: {:.2f}%'.format(classes[i], float(class_acc[i].val) * 100.)
                                          for i in range(len(classes))]))

        # log average values to tensorboard
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('acc', float(accs.avg), epoch, test=False)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=False)

    def test(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0])).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # in case of None labels
                mask = t != -1
                if mask.sum() == 0:
                    continue
                o, t, p = o[mask], t[mask], p[mask]
                loss += criterion(o, t)
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p == t).type(torch.FloatTensor) / t.size(0)).data)
            accs.update(torch.mean(torch.stack([class_acc[i].val for i in range(len(classes))])), target[0].size(0))
            losses.update(loss.data, target[0].size(0))

        score = accs.avg - torch.std(torch.stack([class_acc[i].avg for i in range(
            len(classes))])) / accs.avg  # compute mean - std/mean as measure for accuracy
        print('\nVal set: Average loss: {:.4f} Average acc {:.2f}% Acc score {:.2f} LR: {:.6f}'
              .format(float(losses.avg), float(accs.avg) * 100., float(score), optimizer.param_groups[-1]['lr']))
        print('\t' + '\n\t'.join(['{}: {:.2f}%'.format(classes[i], float(class_acc[i].avg) * 100.)
                                  for i in range(len(classes))]))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('acc', float(accs.avg), epoch, test=True)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=True)
        return losses.avg.cpu().numpy(), float(score), float(accs.avg), [float(class_acc[i].avg) for i in
                                                                         range(len(classes))]

    if start_epoch == 1:  # compute baseline
        _, best_acc_score, best_acc, _ = test(epoch=0)
    else:  # checkpoint was loaded, best_acc_score and best_acc were restored above
        pass

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc_score, val_acc, val_class_accs = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc_score > best_acc_score
        best_acc_score = max(val_acc_score, best_acc_score)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc_score': best_acc_score,
            'acc': val_acc,
            'class_acc': {c: a for c, a in zip(classes, val_class_accs)}
        }, is_best, expname, directory=log_dir)

        if val_acc > best_acc:
            shutil.copyfile(os.path.join(log_dir, expname + '_checkpoint.pth.tar'),
                            os.path.join(log_dir, expname + '_model_best_mean_acc.pth.tar'))
        best_acc = max(val_acc, best_acc)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    try:
        best = torch.load(os.path.join(log_dir, expname + '_model_best.pth.tar'), map_location=lambda storage, loc: storage)
    except IOError:         # could be only one task
        best = torch.load(os.path.join(log_dir, expname + '_model_best_mean_acc.pth.tar'), map_location=lambda storage, loc: storage)
    print('Finished training after epoch {}:\n\tbest acc score: {}\n\tacc: {}\n\t class acc: {}'
          .format(best['epoch'], best['best_acc_score'], best['acc'], best['class_acc']))
    print('Best model mean accuracy: {}'.format(best_acc))

    try:
        shutil.copyfile(os.path.join(log_dir, expname + '_model_best.pth.tar'),
                        os.path.join('models', expname + '_model_best.pth.tar'))
    except IOError:  # could be only one task
        shutil.copyfile(os.path.join(log_dir, expname + '_model_best_mean_acc.pth.tar'),
                        os.path.join('models', expname + '_model_best.pth.tar'))
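
The main() above reads every setting from a module-level args object. A minimal argparse sketch covering only the flags the function actually touches might look like the following; the flag names mirror the args.* attributes used in the code, while the defaults and help texts are assumptions rather than values from the original script.

# Hypothetical CLI setup for the main() above; flag names follow the args.* attributes
# used in the function, everything else (defaults, help strings) is assumed.
import argparse

parser = argparse.ArgumentParser(description='multi-task classifier training (sketch)')
parser.add_argument('--task_selection', default=None,
                    help='comma-separated tasks, e.g. "style,genre"; main() calls .split(",") on this')
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--stat_file', required=True, help='pickle file with dataset "mean" and "std"')
parser.add_argument('--train_file', required=True)
parser.add_argument('--val_file', required=True)
parser.add_argument('--im_path', required=True)
parser.add_argument('--shape_dataset', action='store_true')
parser.add_argument('--office_dataset', action='store_true')
parser.add_argument('--bam_dataset', action='store_true')
parser.add_argument('--model', default='mobilenet_v2', choices=['mobilenet_v2', 'vgg16_bn'])
parser.add_argument('--not_narrow', action='store_true')
parser.add_argument('--feature_dim', type=int, default=512)
parser.add_argument('--use_gpu', action='store_true')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--log_dir', default='runs')
parser.add_argument('--exp_name', default=None)
parser.add_argument('--chkpt', default=None)
args = parser.parse_args()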
Example #2
def train_embedder(embedder,
                   feature,
                   lr=1e-3,
                   batch_size=100,
                   experiment_id=None,
                   random_state=123):
    # log and saving options
    exp_name = 'MapNet_embedder'

    if experiment_id is not None:
        exp_name = experiment_id + '_' + exp_name

    log = TBPlotter(os.path.join('runs/embedder', 'tensorboard', exp_name))
    log.print_logdir()

    outpath_model = os.path.join('runs/embedder/models')
    if not os.path.isdir(outpath_model):
        os.makedirs(outpath_model)

    # general
    use_cuda = torch.cuda.is_available()
    N = len(feature)

    idx_train, idx_test = train_test_split(range(N),
                                           test_size=0.2,
                                           random_state=random_state,
                                           shuffle=True)
    kwargs = {
        'num_workers': 4,
        'drop_last': True
    } if use_cuda else {
        'drop_last': True
    }
    train_loader = DataLoader(
        IndexDataset(feature[idx_train]),
        batch_size=batch_size,  # careful, returned index is now for idx_train selection
        **kwargs)
    test_loader = DataLoader(
        IndexDataset(feature[idx_test]),
        batch_size=batch_size,  # careful, returned index is now for idx_test selection
        **kwargs)

    if use_cuda:
        embedder = embedder.cuda()
    stop_early_compression = 3
    stop_early_exaggeration = 1
    early_exaggeration_factor = 1

    optimizer = torch.optim.Adam(embedder.parameters(),
                                 lr=lr,
                                 weight_decay=2e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=10,
                                                           threshold=1e-3,
                                                           verbose=True)

    train_criterion = TSNELoss(
        N=len(idx_train),
        early_exaggeration_fac=early_exaggeration_factor,
        use_cuda=use_cuda)
    test_criterion = TSNELoss(N=len(idx_test),
                              early_exaggeration_fac=early_exaggeration_factor,
                              use_cuda=use_cuda)
    print('Compute beta for KL-Loss...')
    train_fts = torch.from_numpy(feature[idx_train])
    test_fts = torch.from_numpy(feature[idx_test])
    train_criterion._compute_beta(train_fts.cuda() if use_cuda else train_fts)
    test_criterion._compute_beta(test_fts.cuda() if use_cuda else test_fts)
    print('done...')

    log_interval = 10

    def train(epoch):
        losses = AverageMeter()
        # if epoch == stop_early_compression:
        #     print('stop early compression')

        # switch to train mode
        embedder.train()
        for batch_idx, (fts, idx) in enumerate(train_loader):
            fts = torch.autograd.Variable(
                fts.cuda()) if use_cuda else torch.autograd.Variable(fts)
            outputs = embedder(fts)
            loss = train_criterion(fts, outputs, idx)

            losses.update(loss.data, len(idx))

            # if epoch <= stop_early_compression:
            #     optimizer.weight_decay = 0
            #     loss += 0.01 * outputs.norm(p=2, dim=1).mean().type_as(loss)

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'LR: {:.6f}'.format(epoch, (batch_idx + 1) * len(idx),
                                          len(train_loader.sampler),
                                          float(losses.val), float(losses.avg),
                                          optimizer.param_groups[-1]['lr']))

        log.write('loss', float(losses.avg), epoch, test=False)
        return losses.avg

    def test(epoch):
        losses = AverageMeter()

        # switch to evaluation mode
        embedder.eval()
        for batch_idx, (fts, idx) in enumerate(test_loader):
            fts = torch.autograd.Variable(
                fts.cuda()) if use_cuda else torch.autograd.Variable(fts)
            outputs = embedder(fts)
            loss = test_criterion(fts, outputs, idx)
            losses.update(loss.data, len(idx))

            if batch_idx % log_interval == 0:
                print('Test Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\n'.format(
                          epoch, (batch_idx + 1) * len(idx),
                          len(test_loader.sampler), float(losses.val),
                          float(losses.avg)))

        log.write('loss', float(losses.avg), epoch, test=True)
        return losses.avg

    # train network until scheduler reduces learning rate to threshold value
    lr_threshold = 1e-5
    epoch = 1

    best_loss = float('inf')
    while optimizer.param_groups[-1]['lr'] >= lr_threshold:
        train(epoch)
        testloss = test(epoch)
        scheduler.step(testloss)
        if testloss < best_loss:
            best_loss = testloss
            torch.save(
                {
                    'epoch': epoch,
                    'best_loss': best_loss,
                    'state_dict': embedder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, os.path.join(outpath_model, exp_name + '.pth.tar'))
        epoch += 1

    print('Finished training embedder with best loss: {}'.format(best_loss))

    return os.path.join(outpath_model, exp_name + '.pth.tar')
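
train_embedder() above expects a torch.nn.Module that maps the feature dimension to the embedding dimension, plus a numpy feature matrix; IndexDataset, TSNELoss, AverageMeter and TBPlotter are helpers from the surrounding repository. A hypothetical call could look like the sketch below; the two-layer embedder and the random feature matrix are illustrative placeholders, not part of the original code.

# Hypothetical usage of train_embedder(); the embedder architecture and the random
# feature matrix are assumptions made only for illustration.
import numpy as np
import torch.nn as nn

feature = np.random.randn(5000, 512).astype(np.float32)   # e.g. CNN features, one row per image
embedder = nn.Sequential(                                  # 512-d feature -> 2-d embedding
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 2),
)
model_file = train_embedder(embedder, feature,
                            lr=1e-3, batch_size=100,
                            experiment_id='demo', random_state=123)
print('best embedder weights written to {}'.format(model_file))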
Example #3
def train(net,
          feature,
          image_id,
          old_embedding,
          target_embedding,
          idx_modified,
          idx_old_neighbors,
          idx_new_neighbors,
          lr=1e-3,
          experiment_id=None,
          socket_id=None,
          scale_func=None):
    global cycle, previously_modified
    cycle += 1
    # log and saving options
    exp_name = 'MapNet'

    if experiment_id is not None:
        exp_name = experiment_id + '_' + exp_name

    log = TBPlotter(os.path.join('runs/mapping', 'tensorboard', exp_name))
    log.print_logdir()

    outpath_config = os.path.join('runs/mapping', exp_name, 'configs')
    if not os.path.isdir(outpath_config):
        os.makedirs(outpath_config)
    outpath_embedding = os.path.join('runs/mapping', exp_name, 'embeddings')
    if not os.path.isdir(outpath_embedding):
        os.makedirs(outpath_embedding)
    outpath_feature = os.path.join('runs/mapping', exp_name, 'features')
    if not os.path.isdir(outpath_feature):
        os.makedirs(outpath_feature)
    outpath_model = os.path.join('runs/mapping', exp_name, 'models')
    if not os.path.isdir(outpath_model):
        os.makedirs(outpath_model)

    # general
    N = len(feature)
    use_cuda = torch.cuda.is_available()
    if not isinstance(old_embedding, torch.Tensor):
        old_embedding = torch.from_numpy(old_embedding.copy())
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.from_numpy(target_embedding.copy())
    if use_cuda:
        net = net.cuda()
    net.train()

    # find high dimensional neighbors
    idx_high_dim_neighbors = mutual_k_nearest_neighbors(
        feature, idx_modified,
        k=100)  # use the first 100 nn of modified samples

    # ensure there is no overlap between different index groups
    previously_modified = np.setdiff1d(
        previously_modified,
        idx_modified)  # if sample was modified again, allow change
    neighbors = np.unique(
        np.concatenate(
            [idx_old_neighbors, idx_new_neighbors, idx_high_dim_neighbors]))
    neighbors = np.setdiff1d(neighbors, previously_modified)
    space_samples = np.setdiff1d(
        range(N),
        np.concatenate([idx_modified, neighbors, previously_modified]))

    for i, g1 in enumerate(
        [idx_modified, previously_modified, neighbors, space_samples]):
        for j, g2 in enumerate(
            [idx_modified, previously_modified, neighbors, space_samples]):
            if i != j and len(np.intersect1d(g1, g2)) != 0:
                print('groups: {}, {}'.format(i, j))
                print(np.intersect1d(g1, g2))
                raise RuntimeError('Index groups overlap.')

    print('Group Overview:'
          '\n\tModified samples: {}'
          '\n\tPreviously modified samples: {}'
          '\n\tNeighbor samples: {}'
          '\n\tSpace samples: {}'.format(len(idx_modified),
                                         len(previously_modified),
                                         len(neighbors), len(space_samples)))

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=5,
                                                           threshold=1e-3,
                                                           verbose=True)

    kl_criterion = TSNELoss(N, use_cuda=use_cuda)
    l2_criterion = torch.nn.MSELoss(reduction='none')  # keep the output fixed

    # define the data loaders
    #   sample_loader: modified samples
    #   neighbor_loader: high-dimensional and previous and new neighbors (these will be shown more often in training process)
    #   space_loader: remaining samples (neither modified nor neighbors, only show them once)
    #   fixpoint_loader: previously modified samples, that should not be moved again

    kwargs = {'num_workers': 8} if use_cuda else {}
    batch_size = 1000

    sample_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(idx_modified),
        **kwargs)

    neighbor_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(neighbors),
        **kwargs)

    space_loader = DataLoader(
        IndexDataset(feature),
        batch_size=2 * batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(space_samples),
        drop_last=True,
        **kwargs)

    fixpoint_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(previously_modified),
        **kwargs)

    # train network until scheduler reduces learning rate to threshold value
    track_l2_loss = ChangeRateLogger(n_track=5, threshold=5e-2)
    stop_criterion = False

    embeddings = {}
    model_states = {}
    epoch = 1
    new_features = feature.copy()
    new_embedding = old_embedding.numpy().copy()

    t_beta = []
    t_train = []
    t_tensorboard = []
    t_save = []
    t_send = []
    t_iter = []

    while not stop_criterion:
        t_iter_start = time.time()

        # compute beta for kl loss
        t_beta_start = time.time()
        kl_criterion._compute_beta(new_features)
        t_beta_end = time.time()
        t_beta.append(t_beta_end - t_beta_start)

        # set up losses
        l2_losses = AverageMeter()
        kl_losses = AverageMeter()
        losses = AverageMeter()

        t_load = []
        t_forward = []
        t_loss = []
        t_backprop = []
        t_update = []
        t_tot = []

        # iterate over fix points (assume N_fixpoints >> N_modified)
        t_train_start = time.time()
        t_load_start = time.time()
        print(len(space_loader))  # number of space-sample batches in this epoch
        for batch_idx, (space_data, space_indices) in enumerate(space_loader):
            t_tot_start = time.time()

            # load data

            sample_data, sample_indices = next(iter(sample_loader))
            neighbor_data, neighbor_indices = next(iter(neighbor_loader))
            if len(fixpoint_loader) == 0:
                fixpoint_indices = torch.Tensor([]).type_as(sample_indices)
                data = torch.cat([space_data, sample_data, neighbor_data])
            else:
                fixpoint_data, fixpoint_indices = next(iter(fixpoint_loader))
                data = torch.cat(
                    [space_data, sample_data, neighbor_data, fixpoint_data])
            indices = torch.cat([
                space_indices, sample_indices, neighbor_indices,
                fixpoint_indices
            ])
            input = torch.autograd.Variable(
                data.cuda()) if use_cuda else torch.autograd.Variable(data)

            t_load_end = time.time()
            t_load.append(t_load_end - t_load_start)

            # compute forward
            t_forward_start = time.time()

            fts_mod = net.mapping(input)
            emb_mod = net.embedder(torch.nn.functional.relu(fts_mod))

            t_forward_end = time.time()
            t_forward.append(t_forward_end - t_forward_start)

            # compute losses

            t_loss_start = time.time()

            kl_loss = kl_criterion(fts_mod, emb_mod, indices)
            kl_losses.update(kl_loss.data, len(data))

            idx_l2_fixed = torch.cat(
                [space_indices, sample_indices, fixpoint_indices])
            l2_loss = torch.mean(l2_criterion(
                emb_mod[:len(idx_l2_fixed)],
                target_embedding[idx_l2_fixed].type_as(emb_mod)),
                                 dim=1)
            # weight loss of space samples equally to all modified samples
            n_tot = len(l2_loss)
            n_space, n_sample, n_fixpoint = len(space_indices), len(
                sample_indices), len(fixpoint_indices)
            l2_loss = 0.8 * torch.mean(l2_loss[:n_space]) + 0.2 * torch.mean(
                l2_loss[n_space:])
            assert n_space + n_sample + n_fixpoint == n_tot

            l2_losses.update(l2_loss.data, len(idx_l2_fixed))

            loss = 0.6 * l2_loss + 0.4 * kl_loss.type_as(l2_loss)
            losses.update(loss.data, len(data))

            t_loss_end = time.time()
            t_loss.append(t_loss_end - t_loss_start)
            # backprop

            t_backprop_start = time.time()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t_backprop_end = time.time()
            t_backprop.append(t_backprop_end - t_backprop_start)

            # update

            t_update_start = time.time()

            # update current embedding
            new_embedding[indices] = emb_mod.data.cpu().numpy()

            t_update_end = time.time()
            t_update.append(t_update_end - t_update_start)

            if epoch > 5 and batch_idx >= 2 * len(neighbor_loader):
                print('\tend epoch after {} random fix point samples'.format(
                    (batch_idx + 1) * space_loader.batch_size))
                break

            t_tot_end = time.time()
            t_tot.append(t_tot_end - t_tot_start)

            t_load_start = time.time()

        print('Times:'
              '\n\tLoader: {}'
              '\n\tForward: {}'
              '\n\tLoss: {}'
              '\n\tBackprop: {}'
              '\n\tUpdate: {}'
              '\n\tTotal: {}'.format(
                  np.mean(t_load),
                  np.mean(t_forward),
                  np.mean(t_loss),
                  np.mean(t_backprop),
                  np.mean(t_update),
                  np.mean(t_tot),
              ))

        t_train_end = time.time()
        t_train.append(t_train_end - t_train_start)

        t_tensorboard_start = time.time()
        scheduler.step(losses.avg)
        log.write('l2_loss', float(l2_losses.avg), epoch, test=False)
        log.write('kl_loss', float(kl_losses.avg), epoch, test=False)
        log.write('loss', float(losses.avg), epoch, test=False)
        t_tensorboard_end = time.time()
        t_tensorboard.append(t_tensorboard_end - t_tensorboard_start)

        t_save_start = time.time()

        model_states[epoch] = {
            'epoch': epoch,
            'loss': losses.avg,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        embeddings[epoch] = new_embedding

        t_save_end = time.time()
        t_save.append(t_save_end - t_save_start)

        print('Train Epoch: {}\t'
              'Loss: {:.4f}\t'
              'KL Loss: {:.4f}\t'
              'LR: {:.6f}'.format(epoch, float(losses.avg),
                                  float(kl_losses.avg),
                                  optimizer.param_groups[-1]['lr']))

        t_send_start = time.time()

        # send to server
        if socket_id is not None:
            position = new_embedding if scale_func is None else scale_func(
                new_embedding)
            nodes = make_nodes(position=position, index=True)
            send_payload(nodes, socket_id)

        t_send_end = time.time()
        t_send.append(t_send_end - t_send_start)

        epoch += 1
        stop_criterion = track_l2_loss.add_value(l2_losses.avg)

        t_iter_end = time.time()
        t_iter.append(t_iter_end - t_iter_start)

    print('Times:'
          '\n\tBeta: {}'
          '\n\tTrain: {}'
          '\n\tTensorboard: {}'
          '\n\tSave: {}'
          '\n\tSend: {}'
          '\n\tIteration: {}'.format(
              np.mean(t_beta),
              np.mean(t_train),
              np.mean(t_tensorboard),
              np.mean(t_save),
              np.mean(t_send),
              np.mean(t_iter),
          ))

    print('Training details: '
          '\n\tMean: {}'
          '\n\tMax: {} ({})'
          '\n\tMin: {} ({})'.format(np.mean(t_train), np.max(t_train),
                                    np.argmax(t_train), np.min(t_train),
                                    np.argmin(t_train)))

    previously_modified = np.append(previously_modified, idx_modified)

    # compute new features
    new_features = get_feature(net.mapping, feature)

    print('Save output files...')
    # write output files for the cycle
    outfile_config = os.path.join(outpath_config,
                                  'cycle_{:03d}_config.pkl'.format(cycle))
    outfile_embedding = os.path.join(
        outpath_embedding, 'cycle_{:03d}_embeddings.hdf5'.format(cycle))
    outfile_feature = os.path.join(outpath_feature,
                                   'cycle_{:03d}_feature.hdf5'.format(cycle))
    outfile_model_states = os.path.join(
        outpath_model, 'cycle_{:03d}_models.pth.tar'.format(cycle))

    with h5py.File(outfile_embedding, 'w') as f:
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
        for epoch in embeddings.keys():
            data = embeddings[epoch]
            f.create_dataset(name='epoch_{:04d}'.format(epoch),
                             shape=data.shape,
                             dtype=data.dtype,
                             data=data)

    with h5py.File(outfile_feature, 'w') as f:
        f.create_dataset(name='feature',
                         shape=new_features.shape,
                         dtype=new_features.dtype,
                         data=new_features)
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)

    torch.save(model_states, outfile_model_states)

    # write config file
    config_dict = {
        'idx_modified': idx_modified,
        'idx_old_neighbors': idx_old_neighbors,
        'idx_new_neighbors': idx_new_neighbors,
        'idx_high_dim_neighbors': idx_high_dim_neighbors
    }
    with open(outfile_config, 'w') as f:
        pickle.dump(config_dict, f)
    print('Done.')

    print('Finished training.')
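
Training in the example above stops once ChangeRateLogger decides that the l2 loss has settled (see track_l2_loss and the stop_criterion flag). The class itself is not part of this listing; the sketch below is only an assumption about its behaviour: keep the last n_track values and signal a stop once consecutive relative changes fall below the threshold.

# Assumed sketch of the ChangeRateLogger used above; the real implementation in the
# repository may differ. add_value() returns True once the tracked loss has stopped changing.
class ChangeRateLogger(object):
    def __init__(self, n_track=5, threshold=5e-2, order='smaller'):
        self.n_track = n_track
        self.threshold = threshold
        self.order = order              # 'smaller': stop when the change rate drops below threshold
        self.values = []

    def add_value(self, value):
        self.values.append(float(value))
        if len(self.values) < self.n_track:
            return False
        recent = self.values[-self.n_track:]
        change_rates = [abs(recent[i + 1] - recent[i]) / (abs(recent[i]) + 1e-12)
                        for i in range(len(recent) - 1)]
        if self.order == 'smaller':
            return max(change_rates) < self.threshold
        return min(change_rates) > self.threshold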
Example #4
def train_multiclass(train_file, test_file, stat_file,
                     model='mobilenet_v2',
                     classes=('artist_name', 'genre', 'style', 'technique', 'century'),
                     label_file='_user_labels.pkl',
                     im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
                     chkpt=None, weight_file=None,
                     triplet_selector='semihard', margin=0.2,
                     labels_per_class=4, samples_per_label=4,
                     use_gpu=True, device=0,
                     epochs=100, batch_size=32, lr=1e-4, momentum=0.9,
                     log_interval=10, log_dir='runs',
                     exp_name=None, seed=123):
    argvars = locals().copy()
    torch.manual_seed(seed)

    # LOAD DATASET
    with open(stat_file, 'r') as f:
        data = pickle.load(f)
        mean, std = data['mean'], data['std']
        mean = [float(m) for m in mean]
        std = [float(s) for s in std]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ToTensor(),
                normalize,
    ])

    if model.lower() == 'inception_v3':            # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path, train_transform, classes)
    for c in classes:
        if len(trainset.labels_to_ints[c]) < labels_per_class:
            print('fewer labels in class {} than labels_per_class, using all available labels ({})'
                  .format(c, len(trainset.labels_to_ints[c])))
    valset = create_valset(test_file, im_path, val_transform, trainset.labels_to_ints)
    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    if model.lower() not in ['squeezenet', 'mobilenet_v1', 'mobilenet_v2', 'vgg16_bn', 'inception_v3', 'alexnet']:
        assert False, 'Unknown model {}\n\t+ Choose from: ' \
                      '[squeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'.format(model)
    elif model.lower() == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=weight_file is None)
    elif model.lower() == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=weight_file is None)
    elif model.lower() == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=weight_file is None)
    elif model.lower() == 'inception_v3':
        bodynet = inception_v3(pretrained=weight_file is None)
    elif model.lower() == 'alexnet':
        bodynet = alexnet(pretrained=weight_file is None)
    else:       # squeezenet
        bodynet = squeezenet(pretrained=weight_file is None)

    # Load weights for the body network
    if weight_file is not None:
        print("=> loading weights from '{}'".format(weight_file))
        pretrained_dict = torch.load(weight_file, map_location=lambda storage, loc: storage)['state_dict']
        state_dict = bodynet.state_dict()
        pretrained_dict = {k.replace('bodynet.', ''): v for k, v in pretrained_dict.items()         # in case of multilabel weight file
                           if (k.replace('bodynet.', '') in state_dict.keys() and v.shape == state_dict[k.replace('bodynet.', '')].shape)}  # number of classes might have changed
        # check which weights will be transferred
        if not pretrained_dict == state_dict:  # some changes were made
            for k in set(state_dict) | set(pretrained_dict):
                if k in state_dict.keys() and k not in pretrained_dict.keys():
                    print('\tWeights for "{}" were not found in weight file.'.format(k))
                elif k in pretrained_dict.keys() and k not in state_dict.keys():
                    print('\tWeights for "{}" are not part of the used model.'.format(k))
                elif state_dict[k].shape != pretrained_dict[k].shape:
                    print('\tShapes of "{}" are different in model ({}) and weight file ({}).'.
                          format(k, state_dict[k].shape, pretrained_dict[k].shape))
                else:  # everything is good
                    pass

        state_dict.update(pretrained_dict)
        bodynet.load_state_dict(state_dict)

    net = MetricNet(bodynet, len(classes))

    n_parameters = sum([p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(str(net).split('(', 1)[0], n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + str(net).split('(', 1)[0]
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
        n_labeled = '\t'.join([str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars, os.path.join(log_dir, expname), extras={'n_labeled': n_labeled})


    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, threshold=1e-1, verbose=True)

    if triplet_selector.lower() not in ['random', 'semihard', 'hardest', 'mixed', 'khardest']:
        assert False, 'Unknown option {} for triplet selector. Choose from "random", "semihard", "hardest", ' \
                      '"khardest" or "mixed".'.format(triplet_selector)
    elif triplet_selector.lower() == 'random':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=RandomNegativeTripletSelector(margin, cpu=not use_cuda))
    elif triplet_selector.lower() == 'semihard' or triplet_selector.lower() == 'mixed':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=SemihardNegativeTripletSelector(margin, cpu=not use_cuda))
    elif triplet_selector.lower() == 'khardest':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=KHardestNegativeTripletSelector(margin, k=3, cpu=not use_cuda))
    else:
        criterion = TripletLoss(margin=margin,
                                triplet_selector=HardestNegativeTripletSelector(margin, cpu=not use_cuda))
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    multilabel_train = np.stack([trainset.df[c].values for c in classes]).transpose()
    train_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_train, n_label=labels_per_class,
                                                         n_per_label=samples_per_label, ignore_label=None)
    trainloader = DataLoader(trainset, batch_sampler=train_batch_sampler, **kwargs)
    multilabel_val = np.stack([valset.df[c].values for c in classes]).transpose()
    val_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_val, n_label=labels_per_class,
                                                       n_per_label=samples_per_label, ignore_label=None)
    valloader = DataLoader(valset, batch_sampler=val_batch_sampler, **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt, map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            best_acc_score = checkpoint['best_acc_score']
            best_acc = checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def train(epoch):
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)           # do not compute ap pairs for concealed classes
            gtes.update(gte.data, target[0].size(0))
            distances_ap.update(dist_ap.data, target[0].size(0))
            distances_an.update(dist_an.data, target[0].size(0))
            losses.update(loss.data[0], target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'GTE: {:.2f}% ({:.2f}%)\t'
                      'Non-zero Triplets: {:d} ({:d})'.format(
                    epoch, batch_idx * len(target[0]), len(trainloader) * len(target[0]),
                    float(losses.val), float(losses.avg),
                    float(gtes.val) * 100., float(gtes.avg) * 100.,
                    int(non_zero_triplets.val), int(non_zero_triplets.avg)))

        # log average values to tensorboard
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('gte', float(gtes.avg), epoch, test=False)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=False)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=False)
        log.write('dist_an', float(distances_an.avg), epoch, test=False)

    def test(epoch):
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]
            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)
            gtes.update(gte.data.cpu(), target[0].size(0))
            distances_ap.update(dist_ap.data.cpu(), target[0].size(0))
            distances_an.update(dist_an.data.cpu(), target[0].size(0))
            losses.update(loss.data[0].cpu(), target[0].size(0))

        print('\nVal set: Average loss: {:.4f} Average GTE {:.2f}%, '
              'Average non-zero triplets: {:d} LR: {:.6f}'.format(float(losses.avg), float(gtes.avg) * 100.,
                                                       int(non_zero_triplets.avg),
                                                                  optimizer.param_groups[-1]['lr']))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('gte', float(gtes.avg), epoch, test=True)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=True)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=True)
        log.write('dist_an', float(distances_an.avg), epoch, test=True)
        return losses.avg, 1 - gtes.avg

    if start_epoch == 1:         # compute baseline
        _, best_acc = test(epoch=0)
    else:       # checkpoint was loaded, best_acc was restored above
        pass

    for epoch in range(start_epoch, epochs + 1):
        if triplet_selector.lower() == 'mixed' and epoch == 26:
            criterion.triplet_selector = HardestNegativeTripletSelector(margin, cpu=not use_cuda)
            print('Changed negative selection from semihard to hardest.')
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'best_acc': best_acc,
        }, is_best, expname, directory=log_dir)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    best = torch.load(os.path.join(log_dir, expname + '_model_best.pth.tar'), map_location=lambda storage, loc: storage)
    print('Finished training after epoch {}:\n\tbest acc: {}'
          .format(best['epoch'], best['best_acc']))
    print('Best model mean accuracy: {}'.format(best_acc))
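
All of the examples in this listing accumulate their running loss and accuracy statistics in an AverageMeter, which is not defined here. The sketch below is an assumption about its interface, inferred from how it is used above (.val, .avg and .update(value, n)); the repository's own implementation may differ in its details.

# Assumed AverageMeter helper (interface inferred from its use in the examples above).
class AverageMeter(object):
    """Keeps the most recent value and a running, count-weighted average."""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

Example #5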
def train(net,
          feature,
          image_id,
          old_embedding,
          target_embedding,
          idx_modified,
          idx_old_neighbors,
          idx_new_neighbors,
          idx_negatives,
          lr=1e-3,
          experiment_id=None,
          socket_id=None,
          scale_func=None,
          categories=None,
          label=None):
    global cycle, previously_modified
    cycle += 1
    # log and saving options
    exp_name = 'MapNet'

    if experiment_id is not None:
        exp_name = experiment_id + '_' + exp_name

    log = TBPlotter(os.path.join('runs/mapping', 'tensorboard', exp_name))
    log.print_logdir()

    outpath_config = os.path.join('runs/mapping', exp_name, 'configs')
    if not os.path.isdir(outpath_config):
        os.makedirs(outpath_config)
    outpath_embedding = os.path.join('runs/mapping', exp_name, 'embeddings')
    if not os.path.isdir(outpath_embedding):
        os.makedirs(outpath_embedding)
    outpath_feature = os.path.join('runs/mapping', exp_name, 'features')
    if not os.path.isdir(outpath_feature):
        os.makedirs(outpath_feature)
    outpath_model = os.path.join('runs/mapping', exp_name, 'models')
    if not os.path.isdir(outpath_model):
        os.makedirs(outpath_model)

    # general
    N = len(feature)
    use_cuda = torch.cuda.is_available()
    if not isinstance(old_embedding, torch.Tensor):
        old_embedding = torch.from_numpy(old_embedding.copy())
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.from_numpy(target_embedding.copy())

    if use_cuda:
        net = net.cuda()
    net.train()

    # Set up different groups of indices
    # each sample belongs to exactly one group, the hierarchy is as follows:
    # 1: samples moved by user in this cycle
    # 2: negatives selected through neighbor method
    # 3: new neighborhood
    # 4: samples moved by user in previous cycles
    # 5: old neighborhood
    # 5: high dimensional neighborhood of moved samples (merged with the old neighborhood below)
    # 6: fix points / unrelated (remaining) samples

    # find high dimensional neighbors
    idx_high_dim_neighbors, _ = svm_k_nearest_neighbors(
        feature,
        np.union1d(idx_modified, idx_new_neighbors),
        negative_idcs=idx_negatives,
        k=100
    )  # use the first 100 nn of modified samples          # TODO: Better rely on distance

    # ensure there is no overlap between different index groups
    idx_modified = np.setdiff1d(
        idx_modified, idx_negatives
    )  # just ensure in case negatives have moved accidentally    TODO: BETTER FILTER BEFORE
    idx_new_neighbors = np.setdiff1d(
        idx_new_neighbors, np.concatenate([idx_modified, idx_negatives]))
    idx_previously_modified = np.setdiff1d(
        previously_modified,
        np.concatenate([idx_modified, idx_new_neighbors, idx_negatives]))
    idx_old_neighbors = np.setdiff1d(
        np.concatenate([idx_old_neighbors, idx_high_dim_neighbors]),
        np.concatenate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_negatives
        ]))
    idx_fix_points = np.setdiff1d(
        range(N),
        np.concatenate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_old_neighbors, idx_negatives
        ]))

    for i, g1 in enumerate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_old_neighbors, idx_fix_points, idx_negatives
    ]):
        for j, g2 in enumerate([
                idx_modified, idx_new_neighbors, idx_previously_modified,
                idx_old_neighbors, idx_fix_points, idx_negatives
        ]):
            if i != j and len(np.intersect1d(g1, g2)) != 0:
                print('groups: {}, {}'.format(i, j))
                print(np.intersect1d(g1, g2))
                raise RuntimeError('Index groups overlap.')

    print('Group Overview:'
          '\n\tModified samples: {}'
          '\n\tNegative samples: {}'
          '\n\tNew neighbors: {}'
          '\n\tPreviously modified samples: {}'
          '\n\tOld neighbors: {}'
          '\n\tFix points: {}'.format(len(idx_modified), len(idx_negatives),
                                      len(idx_new_neighbors),
                                      len(idx_previously_modified),
                                      len(idx_old_neighbors),
                                      len(idx_fix_points)))

    # modify label
    label[idx_modified, -1] = 'modified'
    label[idx_negatives, -1] = 'negative'
    label[idx_previously_modified, -1] = 'prev_modified'
    label[idx_new_neighbors, -1] = 'new neighbors'
    label[idx_old_neighbors, -1] = 'old neighbors'
    label[idx_high_dim_neighbors, -1] = 'high dim neighbors'
    label[idx_fix_points, -1] = 'other'

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=5,
                                                           threshold=1e-3,
                                                           verbose=True)

    kl_criterion = TSNELoss(N, use_cuda=use_cuda)
    l2_criterion = torch.nn.MSELoss(reduction='none')  # keep the output fixed
    noise_criterion = NormalizedMSE()

    # define the index samplers for data

    batch_size = 500
    max_len = max(
        len(idx_modified) + len(idx_previously_modified), len(idx_negatives),
        len(idx_new_neighbors), len(idx_old_neighbors), len(idx_fix_points))
    if max_len == len(idx_fix_points):
        n_batches = max_len // (batch_size * 2) + 1
    else:
        n_batches = max_len // batch_size + 1

    sampler_modified = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_modified),
        batch_size=batch_size,
        drop_last=False)

    sampler_negatives = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_negatives),
        batch_size=batch_size,
        drop_last=False)

    sampler_new_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_new_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_prev_modified = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_previously_modified),
        batch_size=batch_size,
        drop_last=False)

    sampler_old_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_old_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_high_dim_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_high_dim_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_fixed = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_fix_points),
        batch_size=2 * batch_size,
        drop_last=False)

    # train network until scheduler reduces learning rate to threshold value
    lr_threshold = 1e-5
    track_l2_loss = ChangeRateLogger(n_track=5,
                                     threshold=5e-2,
                                     order='smaller')
    track_noise_reg = ChangeRateLogger(
        n_track=10, threshold=-1,
        order='smaller')  # only consider order --> negative threshold
    stop_criterion = False

    embeddings = {}
    model_states = {}
    cpu_net = copy.deepcopy(net).cpu() if use_cuda else net
    model_states[0] = {
        'epoch': 0,
        'loss': float('inf'),
        'state_dict': cpu_net.state_dict().copy(),
        'optimizer': optimizer.state_dict().copy(),
        'scheduler': scheduler.state_dict().copy()
    }
    embeddings[0] = old_embedding.numpy().copy()

    epoch = 1
    new_features = feature.copy()
    new_embedding = old_embedding.numpy().copy()

    t_beta = []
    t_train = []
    t_tensorboard = []
    t_save = []
    t_send = []
    t_iter = []

    tensor_feature = torch.from_numpy(feature)
    norms = torch.norm(tensor_feature, p=2, dim=1)
    feature_norm = torch.mean(norms)
    norm_margin = norms.std()
    norm_criterion = SoftNormLoss(norm_value=feature_norm, margin=norm_margin)
    # distance_criterion = NormalizedDistanceLoss()
    # distance_criterion = ContrastiveNormalizedDistanceLoss(margin=0.2 * feature_norm)
    triplet_margin = feature_norm
    triplet_selector = SemihardNegativeTripletSelector(
        margin=triplet_margin,
        cpu=False,
        preselect_index_positives=10,
        preselect_index_negatives=1,
        selection='random')
    distance_criterion = TripletLoss(margin=triplet_margin,
                                     triplet_selector=triplet_selector)
    negative_triplet_collector = []

    del norms
    while not stop_criterion:
        # if epoch < 30:           # do not use dropout at first
        #     net.eval()
        # else:
        net.train()

        t_iter_start = time.time()

        # compute beta for kl loss
        t_beta_start = time.time()
        kl_criterion._compute_beta(new_features)
        t_beta_end = time.time()
        t_beta.append(t_beta_end - t_beta_start)

        # set up losses
        l2_losses = AverageMeter()
        kl_losses = AverageMeter()
        distance_losses = AverageMeter()
        noise_regularization = AverageMeter()
        feature_norm = AverageMeter()
        norm_losses = AverageMeter()
        weight_regularization = AverageMeter()
        losses = AverageMeter()

        t_load = []
        t_forward = []
        t_loss = []
        t_backprop = []
        t_update = []
        t_tot = []

        # iterate over fix points (assume N_fixpoints >> N_modified)
        t_train_start = time.time()
        t_load_start = time.time()
        batch_loaders = []
        for smplr in [
                sampler_modified, sampler_negatives, sampler_new_neighbors,
                sampler_prev_modified, sampler_old_neighbors, sampler_fixed,
                sampler_high_dim_neighbors
        ]:
            batches = list(smplr)
            if len(batches) == 0:
                batches = [[] for i in range(n_batches)]
            while len(batches) < n_batches:
                to = min(n_batches - len(batches), len(batches))
                batches.extend(list(smplr)[:to])
            batch_loaders.append(batches)
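        # at this point every index group yields exactly n_batches batches
        # (short groups are cycled, empty ones padded with empty batches), so
        # the parallel indexing below never runs out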

        for batch_idx in range(n_batches):
            t_tot_start = time.time()

            moved_indices = batch_loaders[0][batch_idx]
            negatives_indices = batch_loaders[1][batch_idx]
            new_neigh_indices = batch_loaders[2][batch_idx]
            prev_moved_indices = batch_loaders[3][batch_idx]
            old_neigh_indices = batch_loaders[4][batch_idx]
            fixed_indices = batch_loaders[5][batch_idx]
            high_neigh_indices = batch_loaders[6][batch_idx]
            n_moved, n_neg, n_new, n_prev, n_old, n_fixed, n_high = (
                len(moved_indices), len(negatives_indices),
                len(new_neigh_indices), len(prev_moved_indices),
                len(old_neigh_indices), len(fixed_indices),
                len(high_neigh_indices))

            # load data
            indices = np.concatenate([
                new_neigh_indices, moved_indices, negatives_indices,
                prev_moved_indices, fixed_indices, old_neigh_indices,
                high_neigh_indices
            ]).astype(int)
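            # skip undersized batches: the KL term presumably needs roughly
            # 3 * perplexity neighbours per sample to estimate similarities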
            if len(indices) < 3 * kl_criterion.perplexity + 2:
                continue
            data = tensor_feature[indices]
            input = torch.autograd.Variable(
                data.cuda()) if use_cuda else torch.autograd.Variable(data)

            t_load_end = time.time()
            t_load.append(t_load_end - t_load_start)

            # compute forward
            t_forward_start = time.time()

            fts_mod = net.mapping(input)
            # fts_mod_noise = net.mapping(input + 0.1 * torch.rand(input.shape).type_as(input))
            fts_mod_noise = net.mapping(input +
                                        torch.rand(input.shape).type_as(input))
            emb_mod = net.embedder(torch.nn.functional.relu(fts_mod))

            t_forward_end = time.time()
            t_forward.append(t_forward_end - t_forward_start)

            # compute losses
            # modified --> KL, L2, Dist
            # new neighborhood --> KL, Dist
            # previously modified --> KL, L2
            # old neighborhood + high dimensional neighborhood --> KL
            # fix point samples --> KL, L2
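            # the weighted sum further below combines the triplet distance
            # term, the L2 anchor to the target embedding, the KL term
            # (presumably a parametric-t-SNE-style loss, given the
            # perplexity/beta handling) and an L1 penalty on the mapping weights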

            t_loss_start = time.time()

            noise_reg = noise_criterion(fts_mod, fts_mod_noise)
            noise_regularization.update(noise_reg.data, len(data))

            kl_loss = kl_criterion(fts_mod, emb_mod, indices)
            kl_losses.update(kl_loss.data, len(data))

            idx_l2_fixed = np.concatenate([
                new_neigh_indices, moved_indices, negatives_indices,
                prev_moved_indices, fixed_indices
            ]).astype(int)
            l2_loss = torch.mean(l2_criterion(
                emb_mod[:n_new + n_moved + n_neg + n_prev + n_fixed],
                target_embedding[idx_l2_fixed].type_as(emb_mod)),
                                 dim=1)
            # weigh loss of space samples equally to all modified samples
            l2_loss = 0.5 * torch.mean(l2_loss[:n_new + n_moved + n_neg + n_prev]) + \
                      0.5 * torch.mean(l2_loss[n_new + n_moved + n_neg + n_prev:])
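            # i.e. the fix-point ("space") samples contribute as much to the
            # L2 anchor as all modified/new/negative/previous samples combined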

            l2_losses.update(l2_loss.data, len(idx_l2_fixed))

            if epoch < 0:
                distance_loss = torch.tensor(0.)
            else:
                # distance_loss = distance_criterion(fts_mod[:n_new + n_moved])
                distance_loss_input = fts_mod[:-(n_old + n_high)] if (
                    n_old + n_high) > 0 else fts_mod
                distance_loss_target = torch.cat([
                    torch.ones(n_new + n_moved),
                    torch.zeros(n_neg + n_prev + n_fixed)
                ])
                distance_loss_weights = torch.cat([
                    torch.ones(n_new + n_moved + n_neg + n_prev),
                    0.5 * torch.ones(n_fixed)
                ])
                # also use high dimensional nn
                # distance_loss_weights = torch.cat([torch.ones(n_new+n_moved+n_neg+n_prev+n_fixed), 0.5*torch.ones(len(high_dim_nn))])

                # if len(high_dim_nn) > 0:
                #     distance_loss_input = torch.cat([distance_loss_input, fts_mod[high_dim_nn]])
                #     distance_loss_target = torch.cat([distance_loss_target, torch.ones(len(high_dim_nn))])
                if n_neg > 0:
                    selected_negatives = {
                        1: np.arange(n_new + n_moved, n_new + n_moved + n_neg)
                    }
                else:
                    selected_negatives = None

                distance_loss, negative_triplets = distance_criterion(
                    distance_loss_input,
                    distance_loss_target,
                    concealed_classes=[0],
                    weights=distance_loss_weights,
                    selected_negatives=selected_negatives)
                if negative_triplets is not None:
                    negative_triplets = np.unique(negative_triplets.numpy())
                    negative_triplets = indices[:-(n_old +
                                                   n_high)][negative_triplets]
                    negative_triplet_collector.extend(negative_triplets)
                distance_loss_noise, _ = distance_criterion(
                    distance_loss_input +
                    torch.rand(distance_loss_input.shape).type_as(
                        distance_loss_input),
                    distance_loss_target,
                    concealed_classes=[0])
                distance_loss = 0.5 * distance_loss + 0.5 * distance_loss_noise
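                # averaging with a noise-perturbed copy of the inputs presumably
                # makes the mined triplet structure robust to small feature
                # perturbations; note that the `epoch < 0` guard above never
                # holds, so this branch is always taken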

            distance_losses.update(distance_loss.data, n_new + n_moved)

            # norm_loss = norm_criterion(torch.mean(fts_mod.norm(p=2, dim=1)))
            # norm_losses.update(norm_loss.data, len(data))

            weight_reg = torch.autograd.Variable(
                torch.tensor(0.)).type_as(l2_loss)
            for param in net.mapping.parameters():
                weight_reg += param.norm(1)
            weight_regularization.update(weight_reg, len(data))

            loss = 1 * distance_loss.type_as(l2_loss) + 5 * l2_loss + 10 * kl_loss.type_as(l2_loss) + \
                   1e-5 * weight_reg.type_as(l2_loss)
            # disabled terms: + norm_loss.type_as(l2_loss) + 1e3 * noise_reg.type_as(l2_loss)
            losses.update(loss.data, len(data))

            t_loss_end = time.time()
            t_loss.append(t_loss_end - t_loss_start)

            feature_norm.update(
                torch.mean(fts_mod.norm(p=2, dim=1)).data, len(data))

            # backprop

            t_backprop_start = time.time()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t_backprop_end = time.time()
            t_backprop.append(t_backprop_end - t_backprop_start)

            # update

            t_update_start = time.time()

            # update current embedding
            new_embedding[indices] = emb_mod.data.cpu().numpy()

            t_update_end = time.time()
            t_update.append(t_update_end - t_update_start)

            if epoch > 5 and (batch_idx + 1) * batch_size >= 2000:
                print('\tend epoch after {} random fix point samples'.format(
                    (batch_idx + 1) * batch_size))
                break

            t_tot_end = time.time()
            t_tot.append(t_tot_end - t_tot_start)

            t_load_start = time.time()

        # print('Times:'
        #       '\n\tLoader: {})'
        #       '\n\tForward: {})'
        #       '\n\tLoss: {})'
        #       '\n\tBackprop: {})'
        #       '\n\tUpdate: {})'
        #       '\n\tTotal: {})'.format(
        #     np.mean(t_load),
        #     np.mean(t_forward),
        #     np.mean(t_loss),
        #     np.mean(t_backprop),
        #     np.mean(t_update),
        #     np.mean(t_tot),
        # ))

        t_train_end = time.time()
        t_train.append(t_train_end - t_train_start)

        t_tensorboard_start = time.time()
        scheduler.step(losses.avg)
        label[np.unique(negative_triplet_collector), -1] = 'negative triplet'
        log.write('l2_loss', float(l2_losses.avg), epoch, test=False)
        log.write('distance_loss',
                  float(distance_losses.avg),
                  epoch,
                  test=False)
        log.write('kl_loss', float(kl_losses.avg), epoch, test=False)
        log.write('noise_regularization',
                  float(noise_regularization.avg),
                  epoch,
                  test=False)
        log.write('feature_norm', float(feature_norm.avg), epoch, test=False)
        log.write('norm_loss', float(norm_losses.avg), epoch, test=False)
        log.write('weight_reg',
                  float(weight_regularization.avg),
                  epoch,
                  test=False)
        log.write('loss', float(losses.avg), epoch, test=False)
        t_tensorboard_end = time.time()
        t_tensorboard.append(t_tensorboard_end - t_tensorboard_start)

        t_save_start = time.time()

        cpu_net = copy.deepcopy(net).cpu() if use_cuda else net

        model_states[epoch] = {
            'epoch': epoch,
            'loss': losses.avg.cpu(),
            'state_dict': cpu_net.state_dict().copy(),
            'optimizer': optimizer.state_dict().copy(),
            'scheduler': scheduler.state_dict().copy()
        }
        embeddings[epoch] = new_embedding

        t_save_end = time.time()
        t_save.append(t_save_end - t_save_start)

        print('Train Epoch: {}\t'
              'Loss: {:.4f}\t'
              'L2 Loss: {:.4f}\t'
              'Distance Loss: {:.4f}\t'
              'KL Loss: {:.4f}\t'
              'Noise Regularization: {:.4f}\t'
              'Weight Regularization: {:.4f}\t'
              'LR: {:.6f}'.format(epoch, float(losses.avg),
                                  float(5 * l2_losses.avg),
                                  float(0.5 * distance_losses.avg),
                                  float(10 * kl_losses.avg),
                                  float(noise_regularization.avg),
                                  float(1e-5 * weight_regularization.avg),
                                  optimizer.param_groups[-1]['lr']))

        t_send_start = time.time()

        # send to server
        if socket_id is not None:
            position = new_embedding if scale_func is None else scale_func(
                new_embedding)
            nodes = make_nodes(position=position, index=True, label=label)
            send_payload(nodes, socket_id, categories=categories)

        t_send_end = time.time()
        t_send.append(t_send_end - t_send_start)

        epoch += 1
        l2_stop_criterion = track_l2_loss.add_value(l2_losses.avg)
        epoch_stop_criterion = epoch > 150
        regularization_stop_criterion = False  #track_noise_reg.add_value(noise_regularization.avg)
        lr_stop_criterion = optimizer.param_groups[-1]['lr'] < lr_threshold
        stop_criterion = any([
            l2_stop_criterion, regularization_stop_criterion,
            lr_stop_criterion, epoch_stop_criterion
        ])
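        # stop as soon as any criterion fires: the L2-loss change criterion,
        # the learning rate dropping below lr_threshold, or more than 150
        # epochs; the noise-regularization criterion is currently disabled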

        t_iter_end = time.time()
        t_iter.append(t_iter_end - t_iter_start)

    print('Times:'
          '\n\tBeta: {}'
          '\n\tTrain: {}'
          '\n\tTensorboard: {}'
          '\n\tSave: {}'
          '\n\tSend: {}'
          '\n\tIteration: {}'.format(
              np.mean(t_beta),
              np.mean(t_train),
              np.mean(t_tensorboard),
              np.mean(t_save),
              np.mean(t_send),
              np.mean(t_iter),
          ))

    print('Training details: '
          '\n\tMean: {}'
          '\n\tMax: {} ({})'
          '\n\tMin: {} ({})'.format(np.mean(t_train), np.max(t_train),
                                    np.argmax(t_train), np.min(t_train),
                                    np.argmin(t_train)))

    previously_modified = np.append(previously_modified, idx_modified)

    # compute new features
    new_features = get_feature(net.mapping, feature)

    # print('Save output files...')
    # write output files for the cycle
    outfile_config = os.path.join(outpath_config,
                                  'cycle_{:03d}_config.pkl'.format(cycle))
    outfile_embedding = os.path.join(
        outpath_embedding, 'cycle_{:03d}_embeddings.hdf5'.format(cycle))
    outfile_feature = os.path.join(outpath_feature,
                                   'cycle_{:03d}_feature.hdf5'.format(cycle))
    outfile_model_states = os.path.join(
        outpath_model, 'cycle_{:03d}_models.pth.tar'.format(cycle))

    with h5py.File(outfile_embedding, 'w') as f:
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
        for epoch in embeddings.keys():
            data = embeddings[epoch]
            f.create_dataset(name='epoch_{:04d}'.format(epoch),
                             shape=data.shape,
                             dtype=data.dtype,
                             data=data)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_embedding)))

    with h5py.File(outfile_feature, 'w') as f:
        f.create_dataset(name='feature',
                         shape=new_features.shape,
                         dtype=new_features.dtype,
                         data=new_features)
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_feature)))

    torch.save(model_states, outfile_model_states)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_model_states)))

    # write config file
    config_dict = {
        'idx_modified': idx_modified,
        'idx_old_neighbors': idx_old_neighbors,
        'idx_new_neighbors': idx_new_neighbors,
        'idx_high_dim_neighbors': idx_high_dim_neighbors
    }
    with open(outfile_config, 'wb') as f:
        pickle.dump(config_dict, f)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_config)))

    print('Done.')

    print('Finished training.')
    return new_embedding
Exemple #6
0
def train_multiclass(
        train_file,
        test_file,
        stat_file,
        model='mobilenet_v2',
        classes=('artist_name', 'genre', 'style', 'technique', 'century'),
        im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
        label_file='_user_labels.pkl',
        chkpt=None,
        weight_file=None,
        use_gpu=True,
        device=0,
        epochs=100,
        batch_size=32,
        lr=1e-4,
        momentum=0.9,
        log_interval=10,
        log_dir='runs',
        exp_name=None,
        seed=123):
    argvars = locals().copy()
    torch.manual_seed(seed)

    # LOAD DATASET
    with open(stat_file, 'r') as f:
        data = pickle.load(f)
        mean, std = data['mean'], data['std']
        mean = [float(m) for m in mean]
        std = [float(s) for s in std]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if model.lower() == 'inception_v3':  # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path,
                               train_transform, classes)
    valset = create_valset(test_file, im_path, val_transform,
                           trainset.labels_to_ints)
    num_labels = [len(trainset.labels_to_ints[c]) for c in classes]

    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    if model.lower() not in [
            'squeezenet', 'mobilenet_v1', 'mobilenet_v2', 'vgg16_bn',
            'inception_v3', 'alexnet'
    ]:
        assert False, 'Unknown model {}\n\t+ Choose from: ' \
                      '[squeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'.format(model)
    elif model.lower() == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=weight_file is None)
    elif model.lower() == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=weight_file is None)
    elif model.lower() == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=weight_file is None)
    elif model.lower() == 'inception_v3':
        bodynet = inception_v3(pretrained=weight_file is None)
    elif model.lower() == 'alexnet':
        bodynet = alexnet(pretrained=weight_file is None)
    else:  # squeezenet
        bodynet = squeezenet(pretrained=weight_file is None)

    # Load weights for the body network
    if weight_file is not None:
        print("=> loading weights from '{}'".format(weight_file))
        pretrained_dict = torch.load(
            weight_file,
            map_location=lambda storage, loc: storage)['state_dict']
        state_dict = bodynet.state_dict()
        pretrained_dict = {
            k.replace('bodynet.', ''): v
            for k, v in
            pretrained_dict.items()  # in case of multilabel weight file
            if (k.replace('bodynet.', '') in state_dict.keys()
                and v.shape == state_dict[k.replace('bodynet.', '')].shape)
        }  # number of classes might have changed
        # check which weights will be transferred
        if not pretrained_dict == state_dict:  # some changes were made
            for k in set(state_dict.keys()) | set(pretrained_dict.keys()):
                if k in state_dict.keys() and k not in pretrained_dict.keys():
                    print('\tWeights for "{}" were not found in weight file.'.
                          format(k))
                elif k in pretrained_dict.keys() and k not in state_dict.keys():
                    print('\tWeights for "{}" are not part of the used model.'
                          .format(k))
                elif state_dict[k].shape != pretrained_dict[k].shape:
                    print(
                        '\tShapes of "{}" are different in model ({}) and weight file ({}).'
                        .format(k, state_dict[k].shape,
                                pretrained_dict[k].shape))
                else:  # everything is good
                    pass

        state_dict.update(pretrained_dict)
        bodynet.load_state_dict(state_dict)
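        # only parameters whose (stripped) names and shapes match have been
        # transferred; every other parameter keeps its initialization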

    net = OctopusNet(bodynet, n_labels=num_labels)

    n_parameters = sum(
        [p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(
        str(net).split('(', 1)[0], n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + str(net).split('(', 1)[0]
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
        n_labeled = '\t'.join(
            [str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars,
                 os.path.join(log_dir, expname),
                 extras={'n_labeled': n_labeled})

    # ININTIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=10,
                                                     threshold=1e-1,
                                                     verbose=True)
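    # ReduceLROnPlateau lowers the learning rate once the validation loss has
    # not improved by more than `threshold` for `patience` epochs; the loop
    # below stops training once the rate falls under 1e-5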
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    valloader = DataLoader(valset,
                           batch_size=batch_size,
                           shuffle=True,
                           **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt,
                                    map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            best_acc_score = checkpoint['best_acc_score']
            best_acc = checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def train(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            if use_cuda:
                data, target = Variable(
                    data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0]),
                            requires_grad=True).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (t != -1).nonzero().view(-1)
                o, t, p = o[labeled], t[labeled], p[labeled]
                loss += criterion(o, t)
                # measure class accuracy and record loss
                class_acc[i].update(
                    (torch.sum(p == t).type(torch.FloatTensor) /
                     t.size(0)).data)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                target[0].size(0))
            losses.update(loss.data, target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'Acc: {:.2f}% ({:.2f}%)'.format(epoch,
                                                      batch_idx * len(data),
                                                      len(trainloader.dataset),
                                                      float(losses.val),
                                                      float(losses.avg),
                                                      float(accs.val) * 100.,
                                                      float(accs.avg) * 100.))
                print('\t' + '\n\t'.join([
                    '{}: {:.2f}%'.format(classes[i],
                                         float(class_acc[i].val) * 100.)
                    for i in range(len(classes))
                ]))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('acc', float(accs.avg), epoch, test=False)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=False)

    def test(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            if use_cuda:
                data, target = Variable(
                    data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0]),
                            requires_grad=True).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                labeled = (t != -1).nonzero().view(-1)
                loss += criterion(o[labeled], t[labeled])
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p[labeled] == t[labeled]).type(
                    torch.FloatTensor) / t[labeled].size(0)).data)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                target[0].size(0))
            losses.update(loss.data, target[0].size(0))

        score = accs.avg - torch.std(
            torch.stack([class_acc[i].avg for i in range(len(classes))])
        ) / accs.avg  # compute mean - std/mean as measure for accuracy
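        # e.g. class accuracies of 0.8 and 0.6 give mean 0.7, std ~0.14 and a
        # score of roughly 0.7 - 0.14 / 0.7 ~ 0.50, penalizing uneven classes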
        print(
            '\nVal set: Average loss: {:.4f} Average acc {:.2f}% Acc score {:.2f} LR: {:.6f}'
            .format(float(losses.avg),
                    float(accs.avg) * 100., float(score),
                    optimizer.param_groups[-1]['lr']))
        print('\t' + '\n\t'.join([
            '{}: {:.2f}%'.format(classes[i],
                                 float(class_acc[i].avg) * 100.)
            for i in range(len(classes))
        ]))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('acc', float(accs.avg), epoch, test=True)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=True)
        return losses.avg.cpu().numpy(), float(score), float(
            accs.avg), [float(class_acc[i].avg) for i in range(len(classes))]

    if start_epoch == 1:  # compute baseline
        _, best_acc_score, best_acc, _ = test(epoch=0)
    # else: best_acc_score and best_acc were already restored from the checkpoint

    for epoch in range(start_epoch, epochs + 1):
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc_score, val_acc, val_class_accs = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc_score > best_acc_score
        best_acc_score = max(val_acc_score, best_acc_score)
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_acc_score': best_acc_score,
                'acc': val_acc,
                'class_acc': {c: a
                              for c, a in zip(classes, val_class_accs)}
            },
            is_best,
            expname,
            directory=log_dir)

        if val_acc > best_acc:
            shutil.copyfile(
                os.path.join(log_dir, expname + '_checkpoint.pth.tar'),
                os.path.join(log_dir,
                             expname + '_model_best_mean_acc.pth.tar'))
        best_acc = max(val_acc, best_acc)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    best = torch.load(os.path.join(log_dir, expname + '_model_best.pth.tar'),
                      map_location=lambda storage, loc: storage)
    print(
        'Finished training after epoch {}:\n\tbest acc score: {}\n\tacc: {}\n\t class acc: {}'
        .format(best['epoch'], best['best_acc_score'], best['acc'],
                best['class_acc']))
    print('Best model mean accuracy: {}'.format(best_acc))