示例#1
0
def evaluate(data_source_f1, data_source_f2, batch_size=10):
    """Resume the running validation sums and extend them over new data.

    The five running totals are read back from the pickle stored at
    ``os.path.join(args.save, 'validvars_loop')``. Every ``args.bptt``-sized
    window of the two validation sources is then pushed through the parallel
    f1/f2 models in eval mode with autograd disabled. The updated totals are
    returned in the same order they were stored.

    Args:
        data_source_f1: batchified validation data for the f1 model; its
            ``size(0)`` dimension is the one that is windowed.
        data_source_f2: batchified validation data for the f2 model, windowed
            with the same offsets as ``data_source_f1``.
        batch_size: forwarded to ``init_hidden`` and to the ``batchify_*``
            calls that build the reference samples.

    Returns:
        ``[total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples]``.
        These are the sums of the model outputs, the sums of ``exp`` of the
        reused outputs, and the number of f1 output elements seen so far.
        NOTE(review): ``total_samples`` counts f1 outputs only, yet callers
        also divide the f2 totals by it. This assumes f1 and f2 outputs have
        the same number of elements; confirm.
    """
    state_file = os.path.join(args.save, 'validvars_loop')
    # NOTE(review): unpickling can execute arbitrary code. This is assumed to
    # be a file written by a previous pass of this same script.
    with open(state_file, 'rb') as fh:
        saved_totals = pickle.load(fh)
    total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples = \
        saved_totals

    # Eval mode (disables dropout) for both estimators.
    model_f1.eval()
    model_f2.eval()
    hidden_f1 = model_f1.init_hidden(batch_size, model_f1.ncell)
    hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)

    # A single fresh process draw supplies the reference samples y~ for both
    # models; uniformly=True selects the uniform variant in batchify_*.
    ref_process = data.new_process(args, valid_length)
    ref_f1 = batchify_f1(ref_process, batch_size, args, uniformly=True)
    ref_f2 = batchify_f2(ref_process, batch_size, args, uniformly=True)

    stop = data_source_f1.size(0) - 1
    for offset in range(0, stop, args.bptt):
        batch_f1, batch_f2, ref_batch_f1, ref_batch_f2 = [
            get_batch_dine(src, offset, args, evaluation=True)
            for src in (data_source_f1, data_source_f2, ref_f1, ref_f2)
        ]

        # Inference only: not recording an autograd graph avoids OOM.
        with torch.no_grad():
            out_f1, reused_f1, hidden_f1 = parallel_model_f1(
                batch_f1, ref_batch_f1, hidden_f1)
            out_f2, reused_f2, hidden_f2 = parallel_model_f2(
                batch_f2, ref_batch_f2, hidden_f2)

        # Extend the running sums with this window.
        total_f1 += out_f1.sum()
        total_f1_exps += reused_f1.exp().sum()
        total_f2 += out_f2.sum()
        total_f2_exps += reused_f2.exp().sum()
        total_samples += out_f1.numel()

        # Re-wrap the carried hidden state before the next window.
        hidden_f1 = repackage_hidden(hidden_f1)
        hidden_f2 = repackage_hidden(hidden_f2)

    return [total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples]
示例#2
0
def evaluate(data_source_f1, data_source_f2, batch_size=10):
    """Return the validation bound of the f1 and f2 DINE models.

    Each ``args.bptt`` window of the two validation sources is fed through the
    parallel models. It is paired with reference samples ``y~``, drawn from a
    fresh ``data.new_process`` sequence batchified with ``uniformly=True``.
    Per window, the bound ``mean(out) - log(mean(exp(out_reused)))`` is
    weighted by the window length ``len(data_f*)`` and summed.

    Args:
        data_source_f1: batchified validation data for the f1 model; its
            ``size(0)`` dimension is the one that is windowed.
        data_source_f2: batchified validation data for the f2 model, windowed
            with the same offsets as ``data_source_f1``.
        batch_size: forwarded to ``init_hidden`` and to the ``batchify_*``
            calls that build the reference samples.

    Returns:
        ``[loss_f1, loss_f2]`` as Python floats. Each is the weighted sum
        divided by ``data_source_f1.size(0)`` or ``data_source_f2.size(0)``
        respectively. Both are ``0.0`` if no window is evaluated (e.g.
        ``size(0) == 1``).
    """
    # Turn on evaluation mode which disables dropout.
    model_f1.eval()
    model_f2.eval()
    # Float start values so the final float() conversion also works when the
    # loop below runs zero times (previously `.item()` on int 0 crashed).
    total_loss_fi = [0.0, 0.0]
    hidden_f1 = model_f1.init_hidden(batch_size, model_f1.ncell)
    hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)

    def dv_bound(out, out_reused):
        # Max-shifted log-mean-exp has the same value as
        # log(mean(exp(out_reused))), but exp() cannot overflow to inf when
        # the model outputs are large.
        shift = out_reused.max()
        return torch.mean(out) - (
            shift + torch.log(torch.mean(torch.exp(out_reused - shift))))

    # Creating validation randata
    _val_tmp_randata = data.new_process(args, valid_length)
    # and setting uniform distribution for y~
    valid_randata_f1 = batchify_f1(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)
    valid_randata_f2 = batchify_f2(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)

    # Evaluation needs no autograd graph. Building one per window only wastes
    # (GPU) memory, so disable autograd for the whole loop.
    with torch.no_grad():
        for i in range(0, data_source_f1.size(0) - 1, args.bptt):
            data_f1 = get_batch_dine(data_source_f1, i, args, evaluation=True)
            data_f2 = get_batch_dine(data_source_f2, i, args, evaluation=True)
            randata_f1 = get_batch_dine(valid_randata_f1, i, args,
                                        evaluation=True)
            randata_f2 = get_batch_dine(valid_randata_f2, i, args,
                                        evaluation=True)
            # forward
            out_f1, out_reused_f1, hidden_f1 = parallel_model_f1(
                data_f1, randata_f1, hidden_f1)
            out_f2, out_reused_f2, hidden_f2 = parallel_model_f2(
                data_f2, randata_f2, hidden_f2)

            # Window-length-weighted accumulation. The sums stay on-device
            # until the end, avoiding a host sync per window.
            total_loss_fi[0] += dv_bound(out_f1, out_reused_f1) * len(data_f1)
            total_loss_fi[1] += dv_bound(out_f2, out_reused_f2) * len(data_f2)

            # hidden repackage
            hidden_f1 = repackage_hidden(hidden_f1)
            hidden_f2 = repackage_hidden(hidden_f2)

    # float() accepts both a 0-dim tensor and the initial 0.0.
    total_loss_fi[0] = float(total_loss_fi[0]) / data_source_f1.size(0)
    total_loss_fi[1] = float(total_loss_fi[1]) / data_source_f2.size(0)

    return total_loss_fi
示例#3
0
def _dine_dv_bound(out, out_reused):
    """Return ``mean(out) - log(mean(exp(out_reused)))`` as a 0-dim tensor.

    The log-mean-exp term uses the max-shift trick,
    ``c + log(mean(exp(x - c)))`` with ``c = max(x)``, so ``exp`` cannot
    overflow to ``inf`` for large outputs. The shift is detached, so autograd
    treats it as a constant. Value and gradient therefore equal those of the
    naive formula up to floating-point rounding.
    """
    shift = out_reused.detach().max()
    log_mean_exp = shift + torch.log(torch.mean(torch.exp(out_reused - shift)))
    return torch.mean(out) - log_mean_exp


def _check_nan_grads(optimizer, name):
    """Log and raise ``KeyboardInterrupt`` if a gradient in param group 0 has a NaN.

    Parameters whose ``grad`` is ``None`` (not reached by any backward pass
    yet) are skipped rather than crashing on attribute access.
    """
    for j, p in enumerate(optimizer.param_groups[0]['params']):
        if p.grad is not None and torch.isnan(p.grad).any():
            logging.info(
                "some nan exists in {} parameter #{} !".format(name, j))
            raise KeyboardInterrupt


def train():
    """Train the f1 and f2 DINE models over one freshly generated sequence.

    A new training sequence is drawn with ``data.new_process`` and
    batchified twice: once as the real inputs and once with
    ``uniformly=True`` as the reference samples ``y~``. The sequence is
    consumed in windows of randomly jittered length (AWD-LSTM style).
    Each window is split column-wise into
    ``args.batch_size // args.small_batch_size`` chunks whose gradients
    accumulate before one optimizer step per model. Both models *maximize*
    ``mean(out) - log(mean(exp(out_reused)))`` by back-propagating its
    negation.

    Reads the module-level ``args``, ``data``, ``train_length``, ``epoch``,
    ``model_f1``/``model_f2``, ``parallel_model_f1``/``parallel_model_f2``
    and ``optimizer_f1``/``optimizer_f2``, and updates the models in place.

    Raises:
        AssertionError: if ``args.batch_size`` is not a multiple of
            ``args.small_batch_size``.
        KeyboardInterrupt: if a NaN gradient is detected. This exception type
            is kept from the original code, presumably so the caller's
            interrupt handler stops training; confirm against the caller.
    """
    assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'

    # Creating train data
    _train_tmp_data = data.new_process(args, train_length)
    train_data_f1 = batchify_f1(_train_tmp_data, args.batch_size, args)
    train_data_f2 = batchify_f2(_train_tmp_data, args.batch_size, args)

    n_chunks = args.batch_size // args.small_batch_size
    # Weight of one small-batch chunk in the full-batch loss and gradient.
    chunk_weight = args.small_batch_size / args.batch_size

    total_loss_fi = [0, 0]
    batches_since_log = 0
    start_time = time.time()
    # One hidden state per column chunk, carried across windows.
    hidden_f1 = [
        model_f1.init_hidden(args.small_batch_size, args.ncell)
        for _ in range(n_chunks)
    ]
    hidden_f2 = [
        model_f2.init_hidden(args.small_batch_size, args.ncell)
        for _ in range(n_chunks)
    ]
    # Reference data for y~ (per the original note: ONLY y is resampled).
    train_randata_f1 = batchify_f1(_train_tmp_data,
                                   args.batch_size,
                                   args,
                                   uniformly=True)
    train_randata_f2 = batchify_f2(_train_tmp_data,
                                   args.batch_size,
                                   args,
                                   uniformly=True)

    # Log progress about five times per epoch. max(1, ...) prevents a
    # modulo-by-zero when the sequence is shorter than one bptt window.
    n_batches = len(train_data_f1) // args.bptt
    log_interval = max(1, int(np.ceil(n_batches / 5)))

    batch, i = 0, 0
    while i < train_data_f1.size(0) - 1 - 1:
        # Usually bptt, occasionally bptt/2, jittered by a normal draw.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # Cap rare very long draws, which could otherwise cause OOM.
        seq_len = min(seq_len, args.bptt + args.max_seq_len_delta)

        # Scale the learning rate with the window length; it is restored in
        # `finally` so an exception cannot leave it scaled.
        lr_tmp_f1 = optimizer_f1.param_groups[0]['lr']
        optimizer_f1.param_groups[0]['lr'] = lr_tmp_f1 * seq_len / args.bptt
        lr_tmp_f2 = optimizer_f2.param_groups[0]['lr']
        optimizer_f2.param_groups[0]['lr'] = lr_tmp_f2 * seq_len / args.bptt
        try:
            # Activating train mode
            model_f1.train()
            model_f2.train()

            data_f1 = get_batch_dine(train_data_f1, i, args, seq_len=seq_len)
            data_f2 = get_batch_dine(train_data_f2, i, args, seq_len=seq_len)
            randata_f1 = get_batch_dine(train_randata_f1, i, args,
                                        seq_len=seq_len)
            randata_f2 = get_batch_dine(train_randata_f2, i, args,
                                        seq_len=seq_len)

            optimizer_f1.zero_grad()
            optimizer_f2.zero_grad()

            for s_id in range(n_chunks):
                start = s_id * args.small_batch_size
                end = start + args.small_batch_size
                cur_data_f1 = data_f1[:, start:end]
                cur_data_f2 = data_f2[:, start:end]
                cur_randata_f1 = randata_f1[:, start:end]
                cur_randata_f2 = randata_f2[:, start:end]

                # Detach the hidden state from earlier windows so backprop
                # does not run all the way to the start of the dataset.
                hidden_f1[s_id] = repackage_hidden(hidden_f1[s_id])
                hidden_f2[s_id] = repackage_hidden(hidden_f2[s_id])

                out_f1, out_reused_f1, hidden_f1[s_id] = parallel_model_f1(
                    cur_data_f1, cur_randata_f1, hidden_f1[s_id])
                out_f2, out_reused_f2, hidden_f2[s_id] = parallel_model_f2(
                    cur_data_f2, cur_randata_f2, hidden_f2[s_id])

                raw_loss_f1 = _dine_dv_bound(out_f1, out_reused_f1)
                raw_loss_f2 = _dine_dv_bound(out_f2, out_reused_f2)

                # Scale out of place. The former `loss[0] *= ...` mutated the
                # very tensor bound to raw_loss_f1, so the logged total was
                # multiplied by chunk_weight twice.
                total_loss_fi[0] += raw_loss_f1.detach() * chunk_weight
                total_loss_fi[1] += raw_loss_f2.detach() * chunk_weight
                # Gradient ascent on the bound: back-propagate its negation.
                (-(raw_loss_f1 * chunk_weight)).backward()
                (-(raw_loss_f2 * chunk_weight)).backward()

                # NaN guard on the accumulated gradients (remove once the
                # 'nan' problem is solved).
                _check_nan_grads(optimizer_f1, 'optimizer_f1')
                _check_nan_grads(optimizer_f2, 'optimizer_f2')

                gc.collect()

            # `clip_grad_norm_` helps prevent the exploding gradient problem
            # in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model_f1.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(model_f2.parameters(), args.clip)

            optimizer_f1.step()
            optimizer_f2.step()
        finally:
            # learning rate - back to normal (also when the NaN guard raises)
            optimizer_f1.param_groups[0]['lr'] = lr_tmp_f1
            optimizer_f2.param_groups[0]['lr'] = lr_tmp_f2

        batches_since_log += 1
        if batch % log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated since the last
            # report, not over the unrelated args.log_interval.
            curr_loss_fi = [
                float(total_loss_fi[0]) / batches_since_log,
                float(total_loss_fi[1]) / batches_since_log,
            ]
            elapsed = time.time() - start_time
            logging.info(
                '| epoch {:3d} | {:3d}/{:3d} batches | lr1 {:f} | lr2 {:f} | ms/batch {:5.4f} | '
                'loss F1 {:6.4f} | loss F2 {:6.4f} | loss {:6.4f}'.format(
                    epoch, batch, n_batches,
                    optimizer_f1.param_groups[0]['lr'],
                    optimizer_f2.param_groups[0]['lr'],
                    elapsed * 1000 / batches_since_log, curr_loss_fi[0],
                    curr_loss_fi[1], curr_loss_fi[1] - curr_loss_fi[0]))
            total_loss_fi = [0, 0]
            batches_since_log = 0
            start_time = time.time()
        batch += 1
        i += seq_len
示例#4
0
        else:  # assuming it's Adam
            optimizer_f1 = torch.optim.Adam(model_f1.parameters(),
                                            lr=args.lr1,
                                            weight_decay=args.wdecay)
            optimizer_f2 = torch.optim.Adam(model_f2.parameters(),
                                            lr=args.lr2,
                                            weight_decay=args.wdecay)
        optimizer_f1.load_state_dict(optimizer_state_f1)
        optimizer_f2.load_state_dict(optimizer_state_f2)
    else:
        assert False, 'loop only on existing model'

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        # Creating validation data
        _val_tmp_data = data.new_process(args, valid_length)
        val_data_f1 = batchify_f1(_val_tmp_data, eval_batch_size, args)
        val_data_f2 = batchify_f2(_val_tmp_data, eval_batch_size, args)
        # Evaluation
        val_loss = evaluate(val_data_f1,
                            val_data_f2,
                            batch_size=eval_batch_size)
        logging.info('Using saved evaluation data - validation loop only!')
        with open(os.path.join(args.save, 'validvars_loop'), 'wb') as f:
            # saving the current aggregated result:
            pickle.dump(
                val_loss, f
            )  # = [total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples]
        f1 = torch.div(val_loss[0], val_loss[-1]) - torch.log(
            torch.div(val_loss[1], val_loss[-1]))
        f2 = torch.div(val_loss[2], val_loss[-1]) - torch.log(
示例#5
0
            # hidden repackage
            hidden_f2 = repackage_hidden(hidden_f2)

        y_givenx[k] = torch.mean(dist_vector).detach().cpu().numpy()

    y_givenx = y_givenx / np.sum(y_givenx)
    return y_vec, y_givenx


# Load the best saved model.
# Loads the f2 model from args.save. torch.load unpickles the stored object,
# so this file must come from a trusted source.
model_f2 = torch.load(os.path.join(args.save, 'model_f2.pt'))
# Wrap for multi-GPU execution, with inputs scattered along dim 1, and move
# it to CUDA.
parallel_model_f2 = nn.DataParallel(model_f2, dim=1).cuda()

# calculating the support
# Support width = max(labels) - min(labels) over a freshly generated
# validation-length draw from data.new_process.
_val_tmp_randata = data.new_process(args, valid_length)
support = torch.max(_val_tmp_randata['labels']) - torch.min(
    _val_tmp_randata['labels'])
# Run on test data.
# Estimate the conditional distribution of Y given X = x_inp over that
# support. Presumably y_vec is the grid of y values and y_givenx the
# normalized P(Y|X) estimate on it (TODO confirm against evaluate_dist).
y_vec, y_givenx = evaluate_dist(x_inp, support)

# Plot the estimate as a filled curve.
plt.rcParams['axes.facecolor'] = 'floralwhite'
plt.fill_between(y_vec, y_givenx)
plt.grid(True, which='both', axis='both')
# NOTE(review): db == -999 looks like a sentinel for "no SNR setting". The SNR
# only appears in the title otherwise; confirm where db is assigned.
if db != -999:
    plt.title('DINE Trained Model - ' + r'$P_{Y|X}$' + ' of {}\n'
              'X={}, P={}, N={}, SNR {}dB dim=1'.format(
                  process, x_inp, args.P, args.N, db))
else:
    plt.title('DINE Trained Model - ' + r'$P_{Y|X}$' + ' of {}\n'
              'X={}, P={}, N={}, dim=1'.format(process, x_inp, args.P, args.N))