Example #1
    def train(epoch):
        losses = AverageMeter()
        # if epoch == stop_early_compression:
        #     print('stop early compression')

        # switch to train mode
        embedder.train()
        for batch_idx, (fts, idx) in enumerate(train_loader):
            fts = torch.autograd.Variable(
                fts.cuda()) if use_cuda else torch.autograd.Variable(fts)
            outputs = embedder(fts)
            loss = train_criterion(fts, outputs, idx)

            losses.update(loss.data, len(idx))

            # if epoch <= stop_early_compression:
            #     optimizer.weight_decay = 0
            #     loss += 0.01 * outputs.norm(p=2, dim=1).mean().type_as(loss)

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'LR: {:.6f}'.format(epoch, (batch_idx + 1) * len(idx),
                                          len(train_loader.sampler),
                                          float(losses.val), float(losses.avg),
                                          optimizer.param_groups[-1]['lr']))

        log.write('loss', float(losses.avg), epoch, test=False)
        return losses.avg
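
Example #1 (and the other snippets) relies on an `AverageMeter` helper that is not shown. A minimal sketch, assuming the common running-average pattern with the `val`/`avg` attributes and `update(value, n)` signature used throughout these examples:

class AverageMeter(object):
    """Track the most recent value and a running (count-weighted) average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # n = number of samples the value was computed over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count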
Example #2
    def test(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            if use_cuda:
                data, target = Variable(
                    data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0]),
                            requires_grad=True).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                labeled = (t != -1).nonzero().view(-1)
                if labeled.numel() == 0:
                    # skip classes without labeled samples in this batch
                    # (mirrors the guard in the train loop and avoids division by zero)
                    continue
                loss += criterion(o[labeled], t[labeled])
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p[labeled] == t[labeled]).type(
                    torch.FloatTensor) / t[labeled].size(0)).data)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                target[0].size(0))
            losses.update(loss.data, target[0].size(0))

        score = accs.avg - torch.std(
            torch.stack([class_acc[i].avg for i in range(len(classes))])
        ) / accs.avg  # compute mean - std/mean as measure for accuracy
        print(
            '\nVal set: Average loss: {:.4f} Average acc {:.2f}% Acc score {:.2f} LR: {:.6f}'
            .format(float(losses.avg),
                    float(accs.avg) * 100., float(score),
                    optimizer.param_groups[-1]['lr']))
        print('\t' + '\n\t'.join([
            '{}: {:.2f}%'.format(classes[i],
                                 float(class_acc[i].avg) * 100.)
            for i in range(len(classes))
        ]))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('acc', float(accs.avg), epoch, test=True)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=True)
        return losses.avg.cpu().numpy(), float(score), float(
            accs.avg), [float(class_acc[i].avg) for i in range(len(classes))]
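
For reference, the validation score above combines the mean class accuracy with its spread over classes ("mean - std/mean"). A small standalone illustration of the same computation with hypothetical per-class accuracies:

import torch

# hypothetical per-class accuracies for one validation pass
class_accs = torch.tensor([0.90, 0.80, 0.85])
mean_acc = class_accs.mean()
score = mean_acc - class_accs.std() / mean_acc  # penalizes uneven per-class performance
print(float(mean_acc), float(score))  # 0.85, ~0.79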
Example #3
    def train(epoch):
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0])).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # in case of None labels
                mask = t != -1
                if mask.sum() == 0:
                    continue
                o, t, p = o[mask], t[mask], p[mask]
                loss += criterion(o, t)
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p == t).type(torch.FloatTensor) / t.size(0)).data)
            accs.update(torch.mean(torch.stack([class_acc[i].val for i in range(len(classes))])), target[0].size(0))
            losses.update(loss.data, target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'Acc: {:.2f}% ({:.2f}%)'.format(
                          epoch, batch_idx * len(data), len(trainloader.dataset),
                          float(losses.val), float(losses.avg),
                          float(accs.val) * 100., float(accs.avg) * 100.))
                print('\t' + '\n\t'.join([
                    '{}: {:.2f}%'.format(classes[i],
                                         float(class_acc[i].val) * 100.)
                    for i in range(len(classes))
                ]))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('acc', float(accs.avg), epoch, test=False)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=False)
Example #4
    def test(epoch):
        losses = AverageMeter()

        # switch to evaluation mode
        embedder.eval()
        for batch_idx, (fts, idx) in enumerate(test_loader):
            fts = torch.autograd.Variable(
                fts.cuda()) if use_cuda else torch.autograd.Variable(fts)
            outputs = embedder(fts)
            loss = test_criterion(fts, outputs, idx)
            losses.update(loss.data, len(idx))

            if batch_idx % log_interval == 0:
                print('Test Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\n'.format(
                          epoch, (batch_idx + 1) * len(idx),
                          len(test_loader.sampler), float(losses.val),
                          float(losses.avg)))

        log.write('loss', float(losses.avg), epoch, test=True)
        return losses.avg
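
Examples #1 and #4 are typically driven by an outer loop that alternates training and evaluation; a hypothetical wiring (the `epochs` and `scheduler` names are assumptions, nothing in the snippets defines them):

best_loss = float('inf')
for epoch in range(1, epochs + 1):
    train(epoch)                    # Example #1
    val_loss = test(epoch)          # Example #4
    scheduler.step(val_loss)        # e.g. torch.optim.lr_scheduler.ReduceLROnPlateau
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(embedder.state_dict(), 'embedder_best.pth.tar')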
Example #5
def train(net,
          feature,
          image_id,
          old_embedding,
          target_embedding,
          idx_modified,
          idx_old_neighbors,
          idx_new_neighbors,
          lr=1e-3,
          experiment_id=None,
          socket_id=None,
          scale_func=None):
    global cycle, previously_modified
    cycle += 1
    # log and saving options
    exp_name = 'MapNet'

    if experiment_id is not None:
        exp_name = experiment_id + '_' + exp_name

    log = TBPlotter(os.path.join('runs/mapping', 'tensorboard', exp_name))
    log.print_logdir()

    outpath_config = os.path.join('runs/mapping', exp_name, 'configs')
    if not os.path.isdir(outpath_config):
        os.makedirs(outpath_config)
    outpath_embedding = os.path.join('runs/mapping', exp_name, 'embeddings')
    if not os.path.isdir(outpath_embedding):
        os.makedirs(outpath_embedding)
    outpath_feature = os.path.join('runs/mapping', exp_name, 'features')
    if not os.path.isdir(outpath_feature):
        os.makedirs(outpath_feature)
    outpath_model = os.path.join('runs/mapping', exp_name, 'models')
    if not os.path.isdir(outpath_model):
        os.makedirs(outpath_model)

    # general
    N = len(feature)
    use_cuda = torch.cuda.is_available()
    if not isinstance(old_embedding, torch.Tensor):
        old_embedding = torch.from_numpy(old_embedding.copy())
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.from_numpy(target_embedding.copy())
    if use_cuda:
        net = net.cuda()
    net.train()

    # find high dimensional neighbors
    idx_high_dim_neighbors = mutual_k_nearest_neighbors(
        feature, idx_modified,
        k=100)  # use the first 100 nn of modified samples

    # ensure there is no overlap between different index groups
    previously_modified = np.setdiff1d(
        previously_modified,
        idx_modified)  # if sample was modified again, allow change
    neighbors = np.unique(
        np.concatenate(
            [idx_old_neighbors, idx_new_neighbors, idx_high_dim_neighbors]))
    neighbors = np.setdiff1d(neighbors, previously_modified)
    space_samples = np.setdiff1d(
        range(N),
        np.concatenate([idx_modified, neighbors, previously_modified]))

    for i, g1 in enumerate(
        [idx_modified, previously_modified, neighbors, space_samples]):
        for j, g2 in enumerate(
            [idx_modified, previously_modified, neighbors, space_samples]):
            if i != j and len(np.intersect1d(g1, g2)) != 0:
                print('groups: {}, {}'.format(i, j))
                print(np.intersect1d(g1, g2))
                raise RuntimeError('Index groups overlap.')

    print('Group Overview:'
          '\n\tModified samples: {}'
          '\n\tPreviously modified samples: {}'
          '\n\tNeighbor samples: {}'
          '\n\tSpace samples: {}'.format(len(idx_modified),
                                         len(previously_modified),
                                         len(neighbors), len(space_samples)))

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=5,
                                                           threshold=1e-3,
                                                           verbose=True)

    kl_criterion = TSNELoss(N, use_cuda=use_cuda)
    l2_criterion = torch.nn.MSELoss(reduction='none')  # keep the output fixed

    # define the data loaders
    #   sample_loader: modified samples
    #   neighbor_loader: high-dimensional, previous, and new neighbors (shown more often during training)
    #   space_loader: remaining samples (neither modified nor neighbors, only show them once)
    #   fixpoint_loader: previously modified samples that should not be moved again

    kwargs = {'num_workers': 8} if use_cuda else {}
    batch_size = 1000

    sample_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(idx_modified),
        **kwargs)

    neighbor_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(neighbors),
        **kwargs)

    space_loader = DataLoader(
        IndexDataset(feature),
        batch_size=2 * batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(space_samples),
        drop_last=True,
        **kwargs)

    fixpoint_loader = DataLoader(
        IndexDataset(feature),
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(previously_modified),
        **kwargs)

    # train network until scheduler reduces learning rate to threshold value
    track_l2_loss = ChangeRateLogger(n_track=5, threshold=5e-2)
    stop_criterion = False

    embeddings = {}
    model_states = {}
    epoch = 1
    new_features = feature.copy()
    new_embedding = old_embedding.numpy().copy()

    t_beta = []
    t_train = []
    t_tensorboard = []
    t_save = []
    t_send = []
    t_iter = []

    while not stop_criterion:
        t_iter_start = time.time()

        # compute beta for kl loss
        t_beta_start = time.time()
        kl_criterion._compute_beta(new_features)
        t_beta_end = time.time()
        t_beta.append(t_beta_end - t_beta_start)

        # set up losses
        l2_losses = AverageMeter()
        kl_losses = AverageMeter()
        losses = AverageMeter()

        t_load = []
        t_forward = []
        t_loss = []
        t_backprop = []
        t_update = []
        t_tot = []

        # iterate over fix points (assume N_fixpoints >> N_modified)
        t_train_start = time.time()
        t_load_start = time.time()
        print(len(space_loader))
        for batch_idx, (space_data, space_indices) in enumerate(space_loader):
            t_tot_start = time.time()

            # load data
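            # note: iter(...) below builds a fresh iterator on every batch, so each call
            # simply draws one new random batch from the respective SubsetRandomSampler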

            sample_data, sample_indices = next(iter(sample_loader))
            neighbor_data, neighbor_indices = next(iter(neighbor_loader))
            if len(fixpoint_loader) == 0:
                fixpoint_indices = torch.Tensor([]).type_as(sample_indices)
                data = torch.cat([space_data, sample_data, neighbor_data])
            else:
                fixpoint_data, fixpoint_indices = next(iter(fixpoint_loader))
                data = torch.cat(
                    [space_data, sample_data, neighbor_data, fixpoint_data])
            indices = torch.cat([
                space_indices, sample_indices, neighbor_indices,
                fixpoint_indices
            ])
            input = torch.autograd.Variable(
                data.cuda()) if use_cuda else torch.autograd.Variable(data)

            t_load_end = time.time()
            t_load.append(t_load_end - t_load_start)

            # compute forward
            t_forward_start = time.time()

            fts_mod = net.mapping(input)
            emb_mod = net.embedder(torch.nn.functional.relu(fts_mod))

            t_forward_end = time.time()
            t_forward.append(t_forward_end - t_forward_start)

            # compute losses

            t_loss_start = time.time()

            kl_loss = kl_criterion(fts_mod, emb_mod, indices)
            kl_losses.update(kl_loss.data, len(data))

            idx_l2_fixed = torch.cat(
                [space_indices, sample_indices, fixpoint_indices])
            l2_loss = torch.mean(l2_criterion(
                emb_mod[:len(idx_l2_fixed)],
                target_embedding[idx_l2_fixed].type_as(emb_mod)),
                                 dim=1)
            # weight loss of space samples equally to all modified samples
            n_tot = len(l2_loss)
            n_space, n_sample, n_fixpoint = len(space_indices), len(
                sample_indices), len(fixpoint_indices)
            l2_loss = 0.8 * torch.mean(l2_loss[:n_space]) + 0.2 * torch.mean(
                l2_loss[n_space:])
            assert n_space + n_sample + n_fixpoint == n_tot

            l2_losses.update(l2_loss.data, len(idx_l2_fixed))

            loss = 0.6 * l2_loss + 0.4 * kl_loss.type_as(l2_loss)
            losses.update(loss.data, len(data))

            t_loss_end = time.time()
            t_loss.append(t_loss_end - t_loss_start)
            # backprop

            t_backprop_start = time.time()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t_backprop_end = time.time()
            t_backprop.append(t_backprop_end - t_backprop_start)

            # update

            t_update_start = time.time()

            # update current embedding
            new_embedding[indices] = emb_mod.data.cpu().numpy()

            t_update_end = time.time()
            t_update.append(t_update_end - t_update_start)

            if epoch > 5 and batch_idx >= 2 * len(neighbor_loader):
                print('\tend epoch after {} random fix point samples'.format(
                    (batch_idx + 1) * space_loader.batch_size))
                break

            t_tot_end = time.time()
            t_tot.append(t_tot_end - t_tot_start)

            t_load_start = time.time()

        print('Times:'
              '\n\tLoader: {}'
              '\n\tForward: {}'
              '\n\tLoss: {}'
              '\n\tBackprop: {}'
              '\n\tUpdate: {}'
              '\n\tTotal: {}'.format(
                  np.mean(t_load),
                  np.mean(t_forward),
                  np.mean(t_loss),
                  np.mean(t_backprop),
                  np.mean(t_update),
                  np.mean(t_tot),
              ))

        t_train_end = time.time()
        t_train.append(t_train_end - t_train_start)

        t_tensorboard_start = time.time()
        scheduler.step(losses.avg)
        log.write('l2_loss', float(l2_losses.avg), epoch, test=False)
        log.write('kl_loss', float(kl_losses.avg), epoch, test=False)
        log.write('loss', float(losses.avg), epoch, test=False)
        t_tensorboard_end = time.time()
        t_tensorboard.append(t_tensorboard_end - t_tensorboard_start)

        t_save_start = time.time()

        model_states[epoch] = {
            'epoch': epoch,
            'loss': losses.avg,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        embeddings[epoch] = new_embedding

        t_save_end = time.time()
        t_save.append(t_save_end - t_save_start)

        print('Train Epoch: {}\t'
              'Loss: {:.4f}\t'
              'KL Loss: {:.4f}\t'
              'LR: {:.6f}'.format(epoch, float(losses.avg),
                                  float(kl_losses.avg),
                                  optimizer.param_groups[-1]['lr']))

        t_send_start = time.time()

        # send to server
        if socket_id is not None:
            position = new_embedding if scale_func is None else scale_func(
                new_embedding)
            nodes = make_nodes(position=position, index=True)
            send_payload(nodes, socket_id)

        t_send_end = time.time()
        t_send.append(t_send_end - t_send_start)

        epoch += 1
        stop_criterion = track_l2_loss.add_value(l2_losses.avg)

        t_iter_end = time.time()
        t_iter.append(t_iter_end - t_iter_start)

    print('Times:'
          '\n\tBeta: {}'
          '\n\tTrain: {}'
          '\n\tTensorboard: {}'
          '\n\tSave: {}'
          '\n\tSend: {}'
          '\n\tIteration: {}'.format(
              np.mean(t_beta),
              np.mean(t_train),
              np.mean(t_tensorboard),
              np.mean(t_save),
              np.mean(t_send),
              np.mean(t_iter),
          ))

    print('Training details: '
          '\n\tMean: {}'
          '\n\tMax: {} ({})'
          '\n\tMin: {} ({})'.format(np.mean(t_train), np.max(t_train),
                                    np.argmax(t_train), np.min(t_train),
                                    np.argmin(t_train)))

    previously_modified = np.append(previously_modified, idx_modified)

    # compute new features
    new_features = get_feature(net.mapping, feature)

    print('Save output files...')
    # write output files for the cycle
    outfile_config = os.path.join(outpath_config,
                                  'cycle_{:03d}_config.pkl'.format(cycle))
    outfile_embedding = os.path.join(
        outpath_embedding, 'cycle_{:03d}_embeddings.hdf5'.format(cycle))
    outfile_feature = os.path.join(outpath_feature,
                                   'cycle_{:03d}_feature.hdf5'.format(cycle))
    outfile_model_states = os.path.join(
        outpath_model, 'cycle_{:03d}_models.pth.tar'.format(cycle))

    with h5py.File(outfile_embedding, 'w') as f:
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
        for epoch in embeddings.keys():
            data = embeddings[epoch]
            f.create_dataset(name='epoch_{:04d}'.format(epoch),
                             shape=data.shape,
                             dtype=data.dtype,
                             data=data)

    with h5py.File(outfile_feature, 'w') as f:
        f.create_dataset(name='feature',
                         shape=new_features.shape,
                         dtype=new_features.dtype,
                         data=new_features)
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)

    torch.save(model_states, outfile_model_states)

    # write config file
    config_dict = {
        'idx_modified': idx_modified,
        'idx_old_neighbors': idx_old_neighbors,
        'idx_new_neighbors': idx_new_neighbors,
        'idx_high_dim_neighbors': idx_high_dim_neighbors
    }
    with open(outfile_config, 'wb') as f:  # binary mode for pickle
        pickle.dump(config_dict, f)
    print('Done.')

    print('Finished training.')
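
Example #5 depends on two small helpers that are not part of the snippet. The sketches below are inferred from their usage only (IndexDataset yields (feature, index) pairs; ChangeRateLogger.add_value returns True once the relative change over the last n_track values stays below threshold); they are assumptions, not the original implementations:

import numpy as np
import torch
from torch.utils.data import Dataset


class IndexDataset(Dataset):
    """Wrap a feature matrix so each item is returned together with its index."""

    def __init__(self, feature):
        self.feature = torch.as_tensor(feature)

    def __len__(self):
        return len(self.feature)

    def __getitem__(self, idx):
        return self.feature[idx], idx


class ChangeRateLogger(object):
    """Signal convergence once the relative change of a tracked value stays small."""

    def __init__(self, n_track=5, threshold=5e-2, order='smaller'):
        self.n_track = n_track
        self.threshold = threshold
        self.order = order
        self.values = []

    def add_value(self, value):
        self.values.append(float(value))
        if len(self.values) < self.n_track:
            return False
        recent = np.array(self.values[-self.n_track:])
        change = np.abs(np.diff(recent)) / (np.abs(recent[:-1]) + 1e-12)
        if self.order == 'smaller':
            return bool(np.all(change < self.threshold))
        return bool(np.all(change > self.threshold))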
Example #6
    def test(epoch):
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]
            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)
            gtes.update(gte.data.cpu(), target[0].size(0))
            distances_ap.update(dist_ap.data.cpu(), target[0].size(0))
            distances_an.update(dist_an.data.cpu(), target[0].size(0))
            losses.update(loss.data[0].cpu(), target[0].size(0))

        print('\nVal set: Average loss: {:.4f} Average GTE {:.2f}%, '
              'Average non-zero triplets: {:d} LR: {:.6f}'.format(
                  float(losses.avg), float(gtes.avg) * 100.,
                  int(non_zero_triplets.avg),
                  optimizer.param_groups[-1]['lr']))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('gte', float(gtes.avg), epoch, test=True)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=True)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=True)
        log.write('dist_an', float(distances_an.avg), epoch, test=True)
        return losses.avg, 1 - gtes.avg
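
The criterion used in Examples #6 and #7 returns both a loss and the number of active (non-zero) triplets. As a generic stand-in, a batch-all triplet loss with that interface could look like the following sketch; this is a substitute illustration, not the project's own implementation, and it assumes L2-normalized embeddings as produced above:

import torch


def batch_all_triplet_loss(embeddings, labels, margin=0.2):
    """Return (mean loss over active triplets, number of active triplets)."""
    dist = torch.cdist(embeddings, embeddings, p=2)        # pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)      # same-label mask
    # valid triplet (a, p, n): label[a] == label[p], label[a] != label[n], a != p
    valid = same.unsqueeze(2) & ~same.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    valid = valid & ~eye.unsqueeze(2)
    # triplet loss tensor: d(a, p) - d(a, n) + margin
    loss = torch.relu(dist.unsqueeze(2) - dist.unsqueeze(1) + margin) * valid.float()
    n_active = int((loss > 1e-16).sum())
    return loss.sum() / max(n_active, 1), n_active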
Example #7
    def train(epoch):
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)           # do not compute ap pairs for concealed classes
            gtes.update(gte.data, target[0].size(0))
            distances_ap.update(dist_ap.data, target[0].size(0))
            distances_an.update(dist_an.data, target[0].size(0))
            losses.update(loss.data[0], target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'GTE: {:.2f}% ({:.2f}%)\t'
                      'Non-zero Triplets: {:d} ({:d})'.format(
                          epoch, batch_idx * len(target[0]),
                          len(trainloader) * len(target[0]),
                          float(losses.val), float(losses.avg),
                          float(gtes.val) * 100., float(gtes.avg) * 100.,
                          int(non_zero_triplets.val), int(non_zero_triplets.avg)))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('gte', float(gtes.avg), epoch, test=False)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=False)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=False)
        log.write('dist_an', float(distances_an.avg), epoch, test=False)
def train(net,
          feature,
          image_id,
          old_embedding,
          target_embedding,
          idx_modified,
          idx_old_neighbors,
          idx_new_neighbors,
          idx_negatives,
          lr=1e-3,
          experiment_id=None,
          socket_id=None,
          scale_func=None,
          categories=None,
          label=None):
    global cycle, previously_modified
    cycle += 1
    # log and saving options
    exp_name = 'MapNet'

    if experiment_id is not None:
        exp_name = experiment_id + '_' + exp_name

    log = TBPlotter(os.path.join('runs/mapping', 'tensorboard', exp_name))
    log.print_logdir()

    outpath_config = os.path.join('runs/mapping', exp_name, 'configs')
    if not os.path.isdir(outpath_config):
        os.makedirs(outpath_config)
    outpath_embedding = os.path.join('runs/mapping', exp_name, 'embeddings')
    if not os.path.isdir(outpath_embedding):
        os.makedirs(outpath_embedding)
    outpath_feature = os.path.join('runs/mapping', exp_name, 'features')
    if not os.path.isdir(outpath_feature):
        os.makedirs(outpath_feature)
    outpath_model = os.path.join('runs/mapping', exp_name, 'models')
    if not os.path.isdir(outpath_model):
        os.makedirs(outpath_model)

    # general
    N = len(feature)
    use_cuda = torch.cuda.is_available()
    if not isinstance(old_embedding, torch.Tensor):
        old_embedding = torch.from_numpy(old_embedding.copy())
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.from_numpy(target_embedding.copy())

    if use_cuda:
        net = net.cuda()
    net.train()

    # Set up different groups of indices.
    # Each sample belongs to exactly one group; the hierarchy is as follows:
    # 1: samples moved by the user in this cycle
    # 2: negatives selected through the neighbor method
    # 3: new neighborhood
    # 4: samples moved by the user in previous cycles
    # 5: old neighborhood
    # 6: high dimensional neighborhood of moved samples
    # 7: fix points / unrelated (remaining) samples

    # # find high dimensional neighbors
    idx_high_dim_neighbors, _ = svm_k_nearest_neighbors(
        feature,
        np.union1d(idx_modified, idx_new_neighbors),
        negative_idcs=idx_negatives,
        k=100
    )  # use the first 100 nn of modified samples          # TODO: Better rely on distance

    # ensure there is no overlap between different index groups
    idx_modified = np.setdiff1d(
        idx_modified, idx_negatives
    )  # guard against negatives that were moved accidentally    TODO: BETTER FILTER BEFORE
    idx_new_neighbors = np.setdiff1d(
        idx_new_neighbors, np.concatenate([idx_modified, idx_negatives]))
    idx_previously_modified = np.setdiff1d(
        previously_modified,
        np.concatenate([idx_modified, idx_new_neighbors, idx_negatives]))
    idx_old_neighbors = np.setdiff1d(
        np.concatenate([idx_old_neighbors, idx_high_dim_neighbors]),
        np.concatenate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_negatives
        ]))
    idx_fix_points = np.setdiff1d(
        range(N),
        np.concatenate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_old_neighbors, idx_negatives
        ]))

    for i, g1 in enumerate([
            idx_modified, idx_new_neighbors, idx_previously_modified,
            idx_old_neighbors, idx_fix_points, idx_negatives
    ]):
        for j, g2 in enumerate([
                idx_modified, idx_new_neighbors, idx_previously_modified,
                idx_old_neighbors, idx_fix_points, idx_negatives
        ]):
            if i != j and len(np.intersect1d(g1, g2)) != 0:
                print('groups: {}, {}'.format(i, j))
                print(np.intersect1d(g1, g2))
                raise RuntimeError('Index groups overlap.')

    print('Group Overview:'
          '\n\tModified samples: {}'
          '\n\tNegative samples: {}'
          '\n\tNew neighbors: {}'
          '\n\tPreviously modified samples: {}'
          '\n\tOld neighbors: {}'
          '\n\tFix points: {}'.format(len(idx_modified), len(idx_negatives),
                                      len(idx_new_neighbors),
                                      len(idx_previously_modified),
                                      len(idx_old_neighbors),
                                      len(idx_fix_points)))

    # modify label
    label[idx_modified, -1] = 'modified'
    label[idx_negatives, -1] = 'negative'
    label[idx_previously_modified, -1] = 'prev_modified'
    label[idx_new_neighbors, -1] = 'new neighbors'
    label[idx_old_neighbors, -1] = 'old neighbors'
    label[idx_high_dim_neighbors, -1] = 'high dim neighbors'
    label[idx_fix_points, -1] = 'other'

    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad], lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=5,
                                                           threshold=1e-3,
                                                           verbose=True)

    kl_criterion = TSNELoss(N, use_cuda=use_cuda)
    l2_criterion = torch.nn.MSELoss(reduction='none')  # keep the output fixed
    noise_criterion = NormalizedMSE()

    # define the index samplers for data

    batch_size = 500
    max_len = max(
        len(idx_modified) + len(idx_previously_modified), len(idx_negatives),
        len(idx_new_neighbors), len(idx_old_neighbors), len(idx_fix_points))
    if max_len == len(idx_fix_points):
        n_batches = max_len // (batch_size * 2) + 1  # integer division: n_batches is used with range()
    else:
        n_batches = max_len // batch_size + 1

    sampler_modified = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_modified),
        batch_size=batch_size,
        drop_last=False)

    sampler_negatives = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_negatives),
        batch_size=batch_size,
        drop_last=False)

    sampler_new_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_new_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_prev_modified = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_previously_modified),
        batch_size=batch_size,
        drop_last=False)

    sampler_old_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_old_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_high_dim_neighbors = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_high_dim_neighbors),
        batch_size=batch_size,
        drop_last=False)

    sampler_fixed = torch.utils.data.BatchSampler(
        sampler=torch.utils.data.SubsetRandomSampler(idx_fix_points),
        batch_size=2 * batch_size,
        drop_last=False)

    # train network until scheduler reduces learning rate to threshold value
    lr_threshold = 1e-5
    track_l2_loss = ChangeRateLogger(n_track=5,
                                     threshold=5e-2,
                                     order='smaller')
    track_noise_reg = ChangeRateLogger(
        n_track=10, threshold=-1,
        order='smaller')  # only consider order --> negative threshold
    stop_criterion = False

    embeddings = {}
    model_states = {}
    cpu_net = copy.deepcopy(net).cpu() if use_cuda else net
    model_states[0] = {
        'epoch': 0,
        'loss': float('inf'),
        'state_dict': cpu_net.state_dict().copy(),
        'optimizer': optimizer.state_dict().copy(),
        'scheduler': scheduler.state_dict().copy()
    }
    embeddings[0] = old_embedding.numpy().copy()

    epoch = 1
    new_features = feature.copy()
    new_embedding = old_embedding.numpy().copy()

    t_beta = []
    t_train = []
    t_tensorboard = []
    t_save = []
    t_send = []
    t_iter = []

    tensor_feature = torch.from_numpy(feature)
    norms = torch.norm(tensor_feature, p=2, dim=1)
    feature_norm = torch.mean(norms)
    norm_margin = norms.std()
    norm_criterion = SoftNormLoss(norm_value=feature_norm, margin=norm_margin)
    # distance_criterion = NormalizedDistanceLoss()
    # distance_criterion = ContrastiveNormalizedDistanceLoss(margin=0.2 * feature_norm)
    triplet_margin = feature_norm
    triplet_selector = SemihardNegativeTripletSelector(
        margin=triplet_margin,
        cpu=False,
        preselect_index_positives=10,
        preselect_index_negatives=1,
        selection='random')
    distance_criterion = TripletLoss(margin=triplet_margin,
                                     triplet_selector=triplet_selector)
    negative_triplet_collector = []

    del norms
    while not stop_criterion:
        # if epoch < 30:           # do not use dropout at first
        #     net.eval()
        # else:
        net.train()

        t_iter_start = time.time()

        # compute beta for kl loss
        t_beta_start = time.time()
        kl_criterion._compute_beta(new_features)
        t_beta_end = time.time()
        t_beta.append(t_beta_end - t_beta_start)

        # set up losses
        l2_losses = AverageMeter()
        kl_losses = AverageMeter()
        distance_losses = AverageMeter()
        noise_regularization = AverageMeter()
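        # note: the AverageMeter below shadows the scalar feature_norm computed before the
        # training loop; norm_criterion and triplet_margin already captured that value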
        feature_norm = AverageMeter()
        norm_losses = AverageMeter()
        weight_regularization = AverageMeter()
        losses = AverageMeter()

        t_load = []
        t_forward = []
        t_loss = []
        t_backprop = []
        t_update = []
        t_tot = []

        # iterate over fix points (assume N_fixpoints >> N_modified)
        t_train_start = time.time()
        t_load_start = time.time()
        batch_loaders = []
        for smplr in [
                sampler_modified, sampler_negatives, sampler_new_neighbors,
                sampler_prev_modified, sampler_old_neighbors, sampler_fixed,
                sampler_high_dim_neighbors
        ]:
            batches = list(smplr)
            if len(batches) == 0:
                batches = [[] for i in range(n_batches)]
            while len(batches) < n_batches:
                to = min(n_batches - len(batches), len(batches))
                batches.extend(list(smplr)[:to])
            batch_loaders.append(batches)

        for batch_idx in range(n_batches):
            t_tot_start = time.time()

            moved_indices = batch_loaders[0][batch_idx]
            negatives_indices = batch_loaders[1][batch_idx]
            new_neigh_indices = batch_loaders[2][batch_idx]
            prev_moved_indices = batch_loaders[3][batch_idx]
            old_neigh_indices = batch_loaders[4][batch_idx]
            fixed_indices = batch_loaders[5][batch_idx]
            high_neigh_indices = batch_loaders[6][batch_idx]
            n_moved, n_neg, n_new, n_prev, n_old, n_fixed, n_high = (
                len(moved_indices), len(negatives_indices),
                len(new_neigh_indices), len(prev_moved_indices),
                len(old_neigh_indices), len(fixed_indices),
                len(high_neigh_indices))

            # load data
            indices = np.concatenate([
                new_neigh_indices, moved_indices, negatives_indices,
                prev_moved_indices, fixed_indices, old_neigh_indices,
                high_neigh_indices
            ]).astype(np.int64)  # avoid the Python 2 only builtin 'long'
            if len(indices) < 3 * kl_criterion.perplexity + 2:
                continue
            data = tensor_feature[indices]
            input = torch.autograd.Variable(
                data.cuda()) if use_cuda else torch.autograd.Variable(data)

            t_load_end = time.time()
            t_load.append(t_load_end - t_load_start)

            # compute forward
            t_forward_start = time.time()

            fts_mod = net.mapping(input)
            # fts_mod_noise = net.mapping(input + 0.1 * torch.rand(input.shape).type_as(input))
            fts_mod_noise = net.mapping(input +
                                        torch.rand(input.shape).type_as(input))
            emb_mod = net.embedder(torch.nn.functional.relu(fts_mod))

            t_forward_end = time.time()
            t_forward.append(t_forward_end - t_forward_start)

            # compute losses
            # modified --> KL, L2, Dist
            # new neighborhood --> KL, Dist
            # previously modified --> KL, L2
            # old neighborhood + high dimensional neighborhood --> KL
            # fix point samples --> KL, L2

            t_loss_start = time.time()

            noise_reg = noise_criterion(fts_mod, fts_mod_noise)
            noise_regularization.update(noise_reg.data, len(data))

            kl_loss = kl_criterion(fts_mod, emb_mod, indices)
            kl_losses.update(kl_loss.data, len(data))

            idx_l2_fixed = np.concatenate([
                new_neigh_indices, moved_indices, negatives_indices,
                prev_moved_indices, fixed_indices
            ]).astype(np.int64)
            l2_loss = torch.mean(l2_criterion(
                emb_mod[:n_new + n_moved + n_neg + n_prev + n_fixed],
                target_embedding[idx_l2_fixed].type_as(emb_mod)),
                                 dim=1)
            # weigh loss of space samples equally to all modified samples
            l2_loss = 0.5 * torch.mean(l2_loss[:n_new + n_moved + n_neg + n_prev]) + \
                      0.5 * torch.mean(l2_loss[n_new + n_moved + n_neg + n_prev:])

            l2_losses.update(l2_loss.data, len(idx_l2_fixed))

            if epoch < 0:
                distance_loss = torch.tensor(0.)
            else:
                # distance_loss = distance_criterion(fts_mod[:n_new + n_moved])
                distance_loss_input = fts_mod[:-(n_old + n_high)] if (
                    n_old + n_high) > 0 else fts_mod
                distance_loss_target = torch.cat([
                    torch.ones(n_new + n_moved),
                    torch.zeros(n_neg + n_prev + n_fixed)
                ])
                distance_loss_weights = torch.cat([
                    torch.ones(n_new + n_moved + n_neg + n_prev),
                    0.5 * torch.ones(n_fixed)
                ])
                # also use high dimensional nn
                # distance_loss_weights = torch.cat([torch.ones(n_new+n_moved+n_neg+n_prev+n_fixed), 0.5*torch.ones(len(high_dim_nn))])

                # if len(high_dim_nn) > 0:
                #     distance_loss_input = torch.cat([distance_loss_input, fts_mod[high_dim_nn]])
                #     distance_loss_target = torch.cat([distance_loss_target, torch.ones(len(high_dim_nn))])
                if n_neg > 0:
                    selected_negatives = {
                        1: np.arange(n_new + n_moved, n_new + n_moved + n_neg)
                    }
                else:
                    selected_negatives = None

                distance_loss, negative_triplets = distance_criterion(
                    distance_loss_input,
                    distance_loss_target,
                    concealed_classes=[0],
                    weights=distance_loss_weights,
                    selected_negatives=selected_negatives)
                if negative_triplets is not None:
                    negative_triplets = np.unique(negative_triplets.numpy())
                    negative_triplets = indices[:-(n_old +
                                                   n_high)][negative_triplets]
                    negative_triplet_collector.extend(negative_triplets)
                distance_loss_noise, _ = distance_criterion(
                    distance_loss_input +
                    torch.rand(distance_loss_input.shape).type_as(
                        distance_loss_input),
                    distance_loss_target,
                    concealed_classes=[0])
                distance_loss = 0.5 * distance_loss + 0.5 * distance_loss_noise

            distance_losses.update(distance_loss.data, n_new + n_moved)

            # norm_loss = norm_criterion(torch.mean(fts_mod.norm(p=2, dim=1)))
            # norm_losses.update(norm_loss.data, len(data))

            weight_reg = torch.autograd.Variable(
                torch.tensor(0.)).type_as(l2_loss)
            for param in net.mapping.parameters():
                weight_reg += param.norm(1)
            weight_regularization.update(weight_reg, len(data))

            loss = 1 * distance_loss.type_as(l2_loss) + 5 * l2_loss + 10 * kl_loss.type_as(l2_loss) + \
                   1e-5 * weight_reg.type_as(l2_loss)  # + norm_loss.type_as(l2_loss) + 1e3 * noise_reg.type_as(l2_loss)
            losses.update(loss.data, len(data))

            t_loss_end = time.time()
            t_loss.append(t_loss_end - t_loss_start)

            feature_norm.update(
                torch.mean(fts_mod.norm(p=2, dim=1)).data, len(data))

            # backprop

            t_backprop_start = time.time()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t_backprop_end = time.time()
            t_backprop.append(t_backprop_end - t_backprop_start)

            # update

            t_update_start = time.time()

            # update current embedding
            new_embedding[indices] = emb_mod.data.cpu().numpy()

            t_update_end = time.time()
            t_update.append(t_update_end - t_update_start)

            if epoch > 5 and (batch_idx + 1) * batch_size >= 2000:
                print('\tend epoch after {} random fix point samples'.format(
                    (batch_idx + 1) * batch_size))
                break

            t_tot_end = time.time()
            t_tot.append(t_tot_end - t_tot_start)

            t_load_start = time.time()

        # print('Times:'
        #       '\n\tLoader: {})'
        #       '\n\tForward: {})'
        #       '\n\tLoss: {})'
        #       '\n\tBackprop: {})'
        #       '\n\tUpdate: {})'
        #       '\n\tTotal: {})'.format(
        #     np.mean(t_load),
        #     np.mean(t_forward),
        #     np.mean(t_loss),
        #     np.mean(t_backprop),
        #     np.mean(t_update),
        #     np.mean(t_tot),
        # ))

        t_train_end = time.time()
        t_train.append(t_train_end - t_train_start)

        t_tensorboard_start = time.time()
        scheduler.step(losses.avg)
        label[np.unique(negative_triplet_collector), -1] = 'negative triplet'
        log.write('l2_loss', float(l2_losses.avg), epoch, test=False)
        log.write('distance_loss',
                  float(distance_losses.avg),
                  epoch,
                  test=False)
        log.write('kl_loss', float(kl_losses.avg), epoch, test=False)
        log.write('noise_regularization',
                  float(noise_regularization.avg),
                  epoch,
                  test=False)
        log.write('feature_norm', float(feature_norm.avg), epoch, test=False)
        log.write('norm_loss', float(norm_losses.avg), epoch, test=False)
        log.write('weight_reg',
                  float(weight_regularization.avg),
                  epoch,
                  test=False)
        log.write('loss', float(losses.avg), epoch, test=False)
        t_tensorboard_end = time.time()
        t_tensorboard.append(t_tensorboard_end - t_tensorboard_start)

        t_save_start = time.time()

        cpu_net = copy.deepcopy(net).cpu() if use_cuda else net

        model_states[epoch] = {
            'epoch': epoch,
            'loss': losses.avg.cpu(),
            'state_dict': cpu_net.state_dict().copy(),
            'optimizer': optimizer.state_dict().copy(),
            'scheduler': scheduler.state_dict().copy()
        }
        embeddings[epoch] = new_embedding

        t_save_end = time.time()
        t_save.append(t_save_end - t_save_start)

        print('Train Epoch: {}\t'
              'Loss: {:.4f}\t'
              'L2 Loss: {:.4f}\t'
              'Distance Loss: {:.4f}\t'
              'KL Loss: {:.4f}\t'
              'Noise Regularization: {:.4f}\t'
              'Weight Regularization: {:.4f}\t'
              'LR: {:.6f}'.format(epoch, float(losses.avg),
                                  float(5 * l2_losses.avg),
                                  float(0.5 * distance_losses.avg),
                                  float(10 * kl_losses.avg),
                                  float(noise_regularization.avg),
                                  float(1e-5 * weight_regularization.avg),
                                  optimizer.param_groups[-1]['lr']))

        t_send_start = time.time()

        # send to server
        if socket_id is not None:
            position = new_embedding if scale_func is None else scale_func(
                new_embedding)
            nodes = make_nodes(position=position, index=True, label=label)
            send_payload(nodes, socket_id, categories=categories)

        t_send_end = time.time()
        t_send.append(t_send_end - t_send_start)

        epoch += 1
        l2_stop_criterion = track_l2_loss.add_value(l2_losses.avg)
        epoch_stop_criterion = epoch > 150
        regularization_stop_criterion = False  #track_noise_reg.add_value(noise_regularization.avg)
        lr_stop_criterion = optimizer.param_groups[-1]['lr'] < lr_threshold
        stop_criterion = any([
            l2_stop_criterion, regularization_stop_criterion,
            lr_stop_criterion, epoch_stop_criterion
        ])

        t_iter_end = time.time()
        t_iter.append(t_iter_end - t_iter_start)

    print('Times:'
          '\n\tBeta: {}'
          '\n\tTrain: {}'
          '\n\tTensorboard: {}'
          '\n\tSave: {}'
          '\n\tSend: {}'
          '\n\tIteration: {}'.format(
              np.mean(t_beta),
              np.mean(t_train),
              np.mean(t_tensorboard),
              np.mean(t_save),
              np.mean(t_send),
              np.mean(t_iter),
          ))

    print('Training details: '
          '\n\tMean: {}'
          '\n\tMax: {} ({})'
          '\n\tMin: {} ({})'.format(np.mean(t_train), np.max(t_train),
                                    np.argmax(t_train), np.min(t_train),
                                    np.argmin(t_train)))

    previously_modified = np.append(previously_modified, idx_modified)

    # compute new features
    new_features = get_feature(net.mapping, feature)

    # print('Save output files...')
    # write output files for the cycle
    outfile_config = os.path.join(outpath_config,
                                  'cycle_{:03d}_config.pkl'.format(cycle))
    outfile_embedding = os.path.join(
        outpath_embedding, 'cycle_{:03d}_embeddings.hdf5'.format(cycle))
    outfile_feature = os.path.join(outpath_feature,
                                   'cycle_{:03d}_feature.hdf5'.format(cycle))
    outfile_model_states = os.path.join(
        outpath_model, 'cycle_{:03d}_models.pth.tar'.format(cycle))

    with h5py.File(outfile_embedding, 'w') as f:
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
        for epoch in embeddings.keys():
            data = embeddings[epoch]
            f.create_dataset(name='epoch_{:04d}'.format(epoch),
                             shape=data.shape,
                             dtype=data.dtype,
                             data=data)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_embedding)))

    with h5py.File(outfile_feature, 'w') as f:
        f.create_dataset(name='feature',
                         shape=new_features.shape,
                         dtype=new_features.dtype,
                         data=new_features)
        f.create_dataset(name='image_id',
                         shape=image_id.shape,
                         dtype=image_id.dtype,
                         data=image_id)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_feature)))

    torch.save(model_states, outfile_model_states)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_model_states)))

    # write config file
    config_dict = {
        'idx_modified': idx_modified,
        'idx_old_neighbors': idx_old_neighbors,
        'idx_new_neighbors': idx_new_neighbors,
        'idx_high_dim_neighbors': idx_high_dim_neighbors
    }
    with open(outfile_config, 'wb') as f:  # binary mode for pickle
        pickle.dump(config_dict, f)
    print('\tSaved {}'.format(os.path.join(os.getcwd(), outfile_config)))

    print('Done.')

    print('Finished training.')
    return new_embedding
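
For orientation, a hypothetical invocation of this interactive mapping routine; every value here is a placeholder inferred from the signature and the indexing above, not taken from the original project:

import numpy as np

# assumes the module-level globals are initialized, e.g.
#   cycle = 0; previously_modified = np.array([], dtype=np.int64)
# feature: (N, D) float array; old/target embedding: (N, 2) arrays;
# label: (N, k) string array whose last column is used for group tags
new_embedding = train(net=mapnet,
                      feature=feature,
                      image_id=image_id,
                      old_embedding=old_embedding,
                      target_embedding=target_embedding,
                      idx_modified=np.array([3, 17, 42]),
                      idx_old_neighbors=np.array([], dtype=np.int64),
                      idx_new_neighbors=np.array([5, 6, 7]),
                      idx_negatives=np.array([99]),
                      lr=1e-3,
                      experiment_id='demo',
                      socket_id=None,
                      label=label)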
    def test(epoch):
        kl_losses = AverageMeter()
        noise_regularization = AverageMeter()
        losses = AverageMeter()

        # switch to evaluation mode
        embedder.eval()
        for batch_idx, (fts, idx) in enumerate(test_loader):
            fts = torch.autograd.Variable(
                fts.cuda()) if use_cuda else torch.autograd.Variable(fts)
            outputs = embedder(fts)
            noise_outputs = embedder(fts +
                                     0.1 * torch.rand(fts.shape).type_as(fts))

            kl_loss = test_criterion(fts, outputs, idx)
            noise_reg = noise_criterion(outputs, noise_outputs)
            loss = kl_loss + 10 * noise_reg.type_as(kl_loss)

            kl_losses.update(kl_loss.data, len(idx))
            noise_regularization.update(noise_reg.data, len(idx))
            losses.update(loss.data, len(idx))

            if batch_idx % log_interval == 0:
                print('Test Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\n'.format(
                          epoch, (batch_idx + 1) * len(idx),
                          len(test_loader.sampler), float(losses.val),
                          float(losses.avg)))

        log.write('kl_loss', float(kl_losses.avg), epoch, test=True)
        log.write('noise_reg',
                  float(noise_regularization.avg),
                  epoch,
                  test=True)
        log.write('loss', float(losses.avg), epoch, test=True)
        return losses.avg
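
The last snippet regularizes the embedder by comparing clean and noise-perturbed outputs via noise_criterion (a NormalizedMSE in the training code above). A plausible sketch of such a criterion, assuming it measures the squared difference relative to the magnitude of the clean output; this is inferred from the usage, not the original class:

import torch


class NormalizedMSE(torch.nn.Module):
    """MSE between clean and perturbed outputs, normalized by the clean output's energy."""

    def forward(self, clean, noisy):
        mse = torch.mean((clean - noisy) ** 2, dim=1)
        norm = torch.mean(clean ** 2, dim=1) + 1e-12
        return torch.mean(mse / norm)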