def _train_epoch(opts, step, network, optimizer, train_data, test_data,
                 label_weight):
    """Train the network for one pass over the (shuffled) training experiments.

    Experiment names are shuffled with ``opts["rng"]`` and consumed in
    mini-batches via ``_get_seq_mini_batch`` until it reports a next batch id
    of -1. For every mini-batch this runs a forward pass, builds the
    per-frame and structured (matching-based) losses, combines them, and
    takes one optimizer step.

    Args:
        opts: options dict; ``opts["rng"]`` is used to shuffle experiments and
            ``opts`` is forwarded to the batching, hidden-state, matching and
            loss helpers.
        step: global step counter at the start of the epoch.
        network: model called as ``network(inputs, hidden)``; the first
            element of the returned pair is the prediction.
        optimizer: optimizer that is zeroed and stepped once per mini-batch.
        train_data: training data container; ``train_data["exp_names"]``
            holds the experiment names (presumably an h5py dataset).
        test_data: not used by this function; kept for interface
            compatibility.
        label_weight: indexable of weights; ``[0]`` and ``[1]`` are passed to
            ``create_pos_neg_masks`` and ``[2]`` to ``create_match_array``.

    Returns:
        ``step`` advanced by the number of mini-batches processed.
    """
    # ``[()]`` reads the whole dataset; ``Dataset.value`` is deprecated in
    # h5py 2.x and removed in h5py 3.0.
    train_exps = train_data["exp_names"][()]
    train_exps = opts["rng"].permutation(train_exps)
    batch_id = 0

    # _get_seq_mini_batch hands back -1 as the next batch id after the last
    # batch of the epoch.
    while batch_id != -1:
        inputs, labels, mask, org_labels, _, batch_id = \
            _get_seq_mini_batch(opts, batch_id, train_data, train_exps)
        hidden = _get_hidden(opts)

        inputs = [
            torch.autograd.Variable(torch.Tensor(feats),
                                    requires_grad=True).cuda()
            for feats in inputs
        ]
        labels = torch.autograd.Variable(torch.Tensor(labels),
                                         requires_grad=False).cuda()
        mask = torch.autograd.Variable(torch.Tensor(mask),
                                       requires_grad=False).cuda()

        train_predict, _ = network(inputs, hidden)

        # The false-positive count (4th value) is not needed for training.
        tp_weight, fp_weight, false_neg, _ = create_match_array(
            opts, train_predict, org_labels, label_weight[2])

        pos_mask, neg_mask = hantman_hungarian.create_pos_neg_masks(
            labels, label_weight[0], label_weight[1])
        perframe_cost = hantman_hungarian.perframe_loss(
            train_predict, mask, labels, pos_mask, neg_mask)
        tp_cost, fp_cost, fn_cost = hantman_hungarian.structured_loss(
            train_predict, mask, tp_weight, fp_weight, false_neg)

        # Only the combined total drives the update; the remaining returned
        # components are not used here.
        total_cost, _, perframe_cost, tp_cost, fp_cost, fn_cost = \
            hantman_hungarian.combine_losses(
                opts, step, perframe_cost, tp_cost, fp_cost, fn_cost)
        cost = total_cost.mean()

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        step += 1

    return step
def hungarian_loss(opts, step, y, yhat, pos_mask, neg_mask, mask, pos_weight,
                   neg_weight):
    """Hungarian loss: combined per-frame + matching-based cost, averaged.

    Args:
        opts: options dict forwarded to ``create_match_array`` and
            ``combine_losses``.
        step: step value forwarded to ``combine_losses``.
        y: labels; passed both to ``create_match_array`` and to
            ``create_pos_neg_masks``.
        yhat: network predictions.
        pos_mask: only its shape is read here (``shape[0]`` is used as the
            sequence length, ``shape[1]`` as the mini-batch size); the value
            itself is replaced by ``create_pos_neg_masks`` below.
        neg_mask: not read; replaced by ``create_pos_neg_masks`` below.
        mask: passed to ``perframe_loss`` and ``structured_loss``.
        pos_weight: must support ``.repeat([...])`` (presumably a torch
            tensor -- TODO confirm). The original value goes to
            ``create_match_array``; a copy tiled by
            ``[seq_len, mini_batch, 1]`` goes to ``create_pos_neg_masks`` and
            ``structured_loss``.
        neg_weight: passed to ``create_match_array`` and
            ``create_pos_neg_masks``.

    Returns:
        ``total_cost.mean()``, where ``total_cost`` is the first value
        returned by ``hantman_hungarian.combine_losses``.

    NOTE(review): the helper calls here do not match the other call sites in
    this file. ``create_match_array`` is called with five arguments here but
    with four elsewhere, and ``structured_loss`` with six here but five
    elsewhere. Confirm against the helpers' current signatures before relying
    on this function.
    """
    # figure out the matches.
    # num_false_pos is returned but not used below.
    TP_weight, FP_weight, num_false_neg, num_false_pos = create_match_array(
        opts, yhat, y, pos_weight, neg_weight)

    seq_len = pos_mask.shape[0]
    mini_batch = pos_mask.shape[1]
    # Tile the weight; the un-tiled value was already used for matching above.
    pos_weight = pos_weight.repeat([seq_len, mini_batch, 1])

    # Recomputes pos_mask / neg_mask, discarding the values passed in.
    pos_mask, neg_mask = hantman_hungarian.create_pos_neg_masks(
        y, pos_weight, neg_weight)
    perframe_cost = hantman_hungarian.perframe_loss(yhat, mask, y, pos_mask,
                                                    neg_mask)
    tp_cost, fp_cost, fn_cost = hantman_hungarian.structured_loss(
        yhat, mask, pos_weight, TP_weight, FP_weight, num_false_neg)

    # Only total_cost is used; the other returned components are discarded.
    total_cost, struct_cost, perframe_cost, tp_cost, fp_cost, fn_cost =\
        hantman_hungarian.combine_losses(opts, step, perframe_cost, tp_cost, fp_cost, fn_cost)
    cost = total_cost.mean()

    return cost
def _predict_write(opts, step, out_dir, network, h5_data, exps, label_weight):
    """Predict every experiment in ``exps``, write predictions, average losses.

    Experiments are processed in sorted order in mini-batches from
    ``_get_seq_mini_batch`` until it reports a next batch id of -1. For each
    batch, the same combined loss used in training is computed. The
    predictions, the per-experiment labels and the frame indices are then
    written with ``sequences_helper.write_predictions2``. A progress line is
    printed every 10 batches.

    Args:
        opts: options dict forwarded to the batching, hidden-state, matching
            and loss helpers.
        step: step value forwarded to ``combine_losses``.
        out_dir: output directory passed to ``write_predictions2``.
        network: model called as ``network(inputs, hidden)``; the first
            element of the returned pair is the prediction.
        h5_data: data container; ``h5_data["exps"][name]["labels"]`` is read
            for every experiment (presumably an h5py file/group).
        exps: experiment names. The array must support ``.copy()``,
            ``.sort()`` and indexing by ``sample_idx`` (presumably a numpy
            array). It is no longer reordered in place.
        label_weight: indexable of weights; ``[0]`` and ``[1]`` are passed to
            ``create_pos_neg_masks`` and ``[2]`` to ``create_match_array``.

    Returns:
        ``(loss, scores)``. ``loss`` is the per-batch mean of the combined
        cost, averaged over batches. ``scores`` is
        ``[tp_cost, fp_cost, fn_cost, perframe_cost]``, each averaged over
        batches.
    """
    # Sort a copy so the caller's array is not reordered as a side effect.
    exps = exps.copy()
    exps.sort()
    batch_id = 0
    loss = 0
    # Accumulated as [tp, fp, fn, perframe].
    scores = [0, 0, 0, 0]
    batch_count = 0
    # NOTE(review): nothing here disables autograd (no volatile/no_grad), so
    # the graph is still recorded during evaluation.
    # _get_seq_mini_batch hands back -1 as the next batch id after the last
    # batch.
    while batch_id != -1:
        if batch_count % 10 == 0:
            print("\t\t%d" % batch_count)
        inputs, labels, mask, org_labels, sample_idx, batch_id = \
            _get_seq_mini_batch(opts, batch_id, h5_data, exps)

        hidden = _get_hidden(opts)
        inputs = [
            torch.autograd.Variable(torch.Tensor(feats)).cuda()
            for feats in inputs
        ]
        labels = torch.autograd.Variable(torch.Tensor(labels)).cuda()
        mask = torch.autograd.Variable(torch.Tensor(mask)).cuda()

        predict, _ = network(inputs, hidden)

        # The false-positive count (4th value) is not used.
        tp_weight, fp_weight, false_neg, _ = create_match_array(
            opts, predict, org_labels, label_weight[2])

        pos_mask, neg_mask = hantman_hungarian.create_pos_neg_masks(
            labels, label_weight[0], label_weight[1])
        perframe_cost = hantman_hungarian.perframe_loss(
            predict, mask, labels, pos_mask, neg_mask)
        tp_cost, fp_cost, fn_cost = hantman_hungarian.structured_loss(
            predict, mask, tp_weight, fp_weight, false_neg)

        total_cost, _, perframe_cost, tp_cost, fp_cost, fn_cost = \
            hantman_hungarian.combine_losses(
                opts, step, perframe_cost, tp_cost, fp_cost, fn_cost)
        cost = total_cost.mean()

        loss += cost.data[0]
        scores[0] += tp_cost.data.cpu()[0]
        scores[1] += fp_cost.data.cpu()[0]
        scores[2] += fn_cost.data.cpu()[0]
        scores[3] += perframe_cost.data.cpu()[0]

        predictions = predict.data.cpu().numpy()

        # Collect the stored labels and frame indices for this batch's
        # experiments. Each dataset is read once; ``[()]`` replaces the
        # deprecated ``Dataset.value`` (removed in h5py 3.0).
        batch_exps = exps[sample_idx]
        batch_labels = []
        frames = []
        for vid in batch_exps:
            vid_labels = h5_data["exps"][vid]["labels"][()]
            batch_labels.append(vid_labels)
            frames.append(list(range(vid_labels.shape[0])))

        sequences_helper.write_predictions2(
            out_dir, batch_exps, predictions, batch_labels,
            [], frames)

        batch_count += 1

    # The loop body runs at least once (batch_id starts at 0), so
    # batch_count >= 1 here.
    loss = loss / batch_count
    scores = [score / batch_count for score in scores]

    return loss, scores