def _init_network(opts, samplers):
    """Initialize the network and its Adam optimizer.

    NOTE(review): this file defines `_init_network` three times; later
    definitions shadow earlier ones, so only the final definition is
    ever reachable at import time.

    Args:
        opts: options dict; reads opts["flags"] (arch, lstm_hidden_dim,
            learning_rate, cuda_device).
        samplers: sequence of samplers; feature dims are taken from the
            first one.

    Returns:
        (network, optimizer) tuple.
    """
    feat_dims = samplers[0].feat_dims
    # label_dims = samplers[0].label_dims
    label_dims = 2

    # Dispatch on the requested architecture; anything other than
    # "concat"/"sum" falls back to the bidirectional concat model.
    arch_classes = {
        "concat": hantman_hungarian.HantmanHungarianConcat,
        "sum": hantman_hungarian.HantmanHungarianSum,
    }
    arch_cls = arch_classes.get(
        opts["flags"].arch, hantman_hungarian.HantmanHungarianBidirConcat)
    network = arch_cls(
        input_dims=feat_dims,
        hidden_dim=opts["flags"].lstm_hidden_dim,
        output_dim=label_dims)
    # network = torch.nn.DataParallel(network, device_ids=[0, 2])

    # create the optimizer too
    optimizer = torch.optim.Adam(
        network.parameters(), lr=opts["flags"].learning_rate)

    # Move the model to the GPU when a CUDA device was requested.
    if opts["flags"].cuda_device != -1:
        network.cuda()

    return network, optimizer
def _init_network(opts, h5_data, label_weight):
    """Setup the network and optimizer from an open HDF5 data file.

    NOTE(review): this is the second of three `_init_network`
    definitions in this file; only the last one is reachable.

    Args:
        opts: options dict; reads opts["flags"] (arch, feat_keys,
            lstm_hidden_dim, learning_rate, cuda_device) and writes
            opts["feat_dims"].
        h5_data: h5py file handle with "exp_names" and per-experiment
            feature/label datasets under "exps".
        label_weight: per-class weights forwarded to the network.

    Returns:
        (network, optimizer) tuple.
    """
    exp_list = h5_data["exp_names"].value
    # BUG FIX: the original referenced an undefined name `train_data`
    # here (NameError at runtime); the feature dims come from the same
    # h5_data handle that was passed in.
    opts["feat_dims"] = [
        h5_data["exps"][exp_list[0]][feat_key].shape[2]
        for feat_key in opts["flags"].feat_keys
    ]
    # num_input = h5_data["exps"][exp_list[0]]["reduced"].shape[2]
    num_classes = h5_data["exps"][exp_list[0]]["labels"].shape[2]

    # Dispatch on architecture; default is the bidirectional concat model.
    if opts["flags"].arch == "concat":
        network = hantman_hungarian.HantmanHungarianConcat(
            input_dims=opts["feat_dims"],
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=num_classes,
            label_weight=label_weight
        )
        # network = torch.nn.DataParallel(network, device_ids=[0, 2])
    elif opts["flags"].arch == "sum":
        network = hantman_hungarian.HantmanHungarianSum(
            input_dims=opts["feat_dims"],
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=num_classes,
            label_weight=label_weight
        )
    else:
        network = hantman_hungarian.HantmanHungarianBidirConcat(
            input_dims=opts["feat_dims"],
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=num_classes,
            label_weight=label_weight
        )

    # create the optimizer too
    optimizer = torch.optim.Adam(
        network.parameters(), lr=opts["flags"].learning_rate)

    # Move the model to the GPU when a CUDA device was requested.
    if opts["flags"].cuda_device != -1:
        network.cuda()

    return network, optimizer
def _init_network(opts, samplers, label_weights):
    """Initialize the network, optimizer and loss criterion.

    This is the last (and therefore the effective) of the three
    `_init_network` definitions in this file. Unlike the earlier two it
    also loads pretrained weights from opts["flags"].model and builds a
    loss criterion selected by opts["flags"].loss.

    Args:
        opts: options dict; reads opts["flags"] (hantman_arch,
            lstm_hidden_dim, mini_batch, perframe_decay_step, model,
            learning_rate, loss, cuda_device) and writes
            opts["flags"].perframe_decay_step / .iter_per_epoch.
        samplers: sequence of samplers; dims and experiment names come
            from the first one.
        label_weights: per-class weights passed to the weighted criteria.

    Returns:
        (network, optimizer, criterion) tuple.

    Raises:
        ValueError: if opts["flags"].loss names an unknown criterion.
    """
    feat_dims = samplers[0].feat_dims
    label_dims = samplers[0].label_dims

    # compute the number of iterations per epoch (kept fractional on
    # purpose; the np.ceil variant was deliberately commented out).
    num_exp = len(samplers[0].exp_names)
    # iter_per_epoch =\
    #     np.ceil(1.0 * num_exp / opts["flags"].mini_batch)
    iter_per_epoch = 1.0 * num_exp / opts["flags"].mini_batch
    # Rescale the decay step from epochs into iterations, in place.
    opts["flags"].perframe_decay_step = iter_per_epoch * opts[
        "flags"].perframe_decay_step
    opts["flags"].iter_per_epoch = iter_per_epoch

    # initialize the network; note this variant keys off `hantman_arch`,
    # whereas the earlier (shadowed) definitions used `arch`.
    if opts["flags"].hantman_arch == "concat":
        network = hantman_hungarian.HantmanHungarianConcat(
            input_dims=feat_dims,
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=label_dims)
    elif opts["flags"].hantman_arch == "sum":
        network = hantman_hungarian.HantmanHungarianSum(
            input_dims=feat_dims,
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=label_dims)
    else:
        network = hantman_hungarian.HantmanHungarianBidirConcat(
            input_dims=feat_dims,
            hidden_dim=opts["flags"].lstm_hidden_dim,
            output_dim=label_dims)
    # Restore pretrained weights before building the optimizer.
    network.load_state_dict(torch.load(opts["flags"].model))

    # create the optimizer too (with a small weight decay, unlike the
    # earlier definitions).
    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=opts["flags"].learning_rate,
                                 weight_decay=0.00001)

    # next the criterion
    if opts["flags"].loss == "mse":
        # NOTE(review): `size_average=False` is deprecated in modern
        # torch (equivalent to reduction="sum"); kept as-is so behavior
        # on the torch version this was written for is unchanged.
        temp = torch.nn.MSELoss(size_average=False)
        if opts["flags"].cuda_device != -1:
            temp.cuda()

        def criterion(step, y, yhat, pos_mask, neg_mask, frame_mask):
            # step and the masks are accepted only for interface parity
            # with the other criteria; plain MSE ignores them.
            return temp(y, yhat)
    elif opts["flags"].loss == "weighted_mse":
        criterion = construct_weigthed_mse(opts, label_weights)
    elif opts["flags"].loss == "hungarian":
        criterion = construct_hungarian(opts, label_weights)
    elif opts["flags"].loss == "wasserstein":
        criterion = construct_wasserstein(opts, label_weights)
    else:
        # BUG FIX: previously an unrecognized loss name left `criterion`
        # unbound and the function crashed with an opaque NameError at
        # the return statement; fail fast with a clear message instead.
        raise ValueError(
            "Unknown loss: %s" % opts["flags"].loss)

    # Move the model to the GPU when a CUDA device was requested.
    if opts["flags"].cuda_device != -1:
        network.cuda()
    return network, optimizer, criterion