Code example #1
File: main_validloop.py Project: omerlux/DINE
def evaluate(data_source_f1, data_source_f2, batch_size=10):
    with open(os.path.join(args.save, 'validvars_loop'), 'rb') as f:
        # Load the running sums of outputs and exponentiated outputs so the validation estimate accumulates from the start of evaluation
        [total_f1, total_f1_exps, total_f2, total_f2_exps,
         total_samples] = pickle.load(f)

    # Turn on evaluation mode which disables dropout.
    model_f1.eval()
    model_f2.eval()
    hidden_f1 = model_f1.init_hidden(batch_size, model_f1.ncell)
    hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)

    # Creating validation randata
    _val_tmp_randata = data.new_process(args, valid_length)
    # and setting uniform distribution for y~
    valid_randata_f1 = batchify_f1(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)
    valid_randata_f2 = batchify_f2(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)

    for i in range(0, data_source_f1.size(0) - 1, args.bptt):
        data_f1 = get_batch_dine(data_source_f1, i, args, evaluation=True)
        data_f2 = get_batch_dine(data_source_f2, i, args, evaluation=True)
        randata_f1 = get_batch_dine(valid_randata_f1, i, args, evaluation=True)
        randata_f2 = get_batch_dine(valid_randata_f2, i, args, evaluation=True)
        # forward
        with torch.no_grad():  # disable gradient tracking to avoid OOM during evaluation
            out_f1, out_reused_f1, hidden_f1 = parallel_model_f1(
                data_f1, randata_f1, hidden_f1)
            out_f2, out_reused_f2, hidden_f2 = parallel_model_f2(
                data_f2, randata_f2, hidden_f2)

        # loss aggregation
        total_f1 += torch.sum(out_f1)
        total_f1_exps += torch.sum(torch.exp(out_reused_f1))
        total_f2 += torch.sum(out_f2)
        total_f2_exps += torch.sum(torch.exp(out_reused_f2))
        total_samples += torch.numel(out_f1)

        # hidden repackage
        hidden_f1 = repackage_hidden(hidden_f1)
        hidden_f2 = repackage_hidden(hidden_f2)

    return [total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples]
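
The sums returned above are not a loss by themselves. A hedged sketch of how a caller might recombine them (the helper name is hypothetical; only the formula mirrors the per-batch loss in code example #2 below and the logged difference in code example #5):

import torch

def dine_estimate_from_aggregates(total_f1, total_f1_exps,
                                  total_f2, total_f2_exps, total_samples):
    # Recombine the running sums into the two Donsker-Varadhan terms,
    # mirroring the per-batch loss in code example #2.
    loss_f1 = total_f1 / total_samples - torch.log(total_f1_exps / total_samples)
    loss_f2 = total_f2 / total_samples - torch.log(total_f2_exps / total_samples)
    # As in the training log of code example #5, the reported value is their difference.
    return loss_f2 - loss_f1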
Code example #2
def evaluate(data_source_f1, data_source_f2, batch_size=10):
    # Turn on evaluation mode which disables dropout.
    model_f1.eval()
    model_f2.eval()
    total_loss_fi = [0, 0]
    hidden_f1 = model_f1.init_hidden(batch_size, model_f1.ncell)
    hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)

    # Creating validation randata
    _val_tmp_randata = data.new_process(args, valid_length)
    # and setting uniform distribution for y~
    valid_randata_f1 = batchify_f1(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)
    valid_randata_f2 = batchify_f2(_val_tmp_randata,
                                   batch_size,
                                   args,
                                   uniformly=True)

    for i in range(0, data_source_f1.size(0) - 1, args.bptt):
        data_f1 = get_batch_dine(data_source_f1, i, args, evaluation=True)
        data_f2 = get_batch_dine(data_source_f2, i, args, evaluation=True)
        randata_f1 = get_batch_dine(valid_randata_f1, i, args, evaluation=True)
        randata_f2 = get_batch_dine(valid_randata_f2, i, args, evaluation=True)
        # forward
        out_f1, out_reused_f1, hidden_f1 = parallel_model_f1(
            data_f1, randata_f1, hidden_f1)
        out_f2, out_reused_f2, hidden_f2 = parallel_model_f2(
            data_f2, randata_f2, hidden_f2)

        # loss calculation
        raw_loss_f1 = torch.mean(out_f1) - torch.log(
            torch.mean(torch.exp(out_reused_f1)))
        raw_loss_f2 = torch.mean(out_f2) - torch.log(
            torch.mean(torch.exp(out_reused_f2)))
        total_loss_fi[0] += raw_loss_f1.data * len(data_f1)
        total_loss_fi[1] += raw_loss_f2.data * len(data_f2)

        # hidden repackage
        hidden_f1 = repackage_hidden(hidden_f1)
        hidden_f2 = repackage_hidden(hidden_f2)

    total_loss_fi[0] = total_loss_fi[0].item() / (data_source_f1.size(0))
    total_loss_fi[1] = total_loss_fi[1].item() / (data_source_f2.size(0))

    return total_loss_fi
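
The per-batch quantity above is a Donsker-Varadhan lower bound: the mean of the network output on the data batch, minus the log of the mean of the exponentiated output on the batch whose y-part was drawn uniformly (the uniformly=True batches). A minimal, self-contained sketch of that term with toy tensors (the function name and shapes are illustrative only):

import torch

def dv_lower_bound(out_joint, out_reference):
    # out_joint:     network output T(x, y) on samples from the joint distribution
    # out_reference: network output T(x, y~) on samples whose y~ comes from the
    #                uniform reference distribution
    return torch.mean(out_joint) - torch.log(torch.mean(torch.exp(out_reference)))

# Toy outputs of shape (seq_len, batch, 1)
out_joint = torch.randn(50, 10, 1)
out_reference = torch.randn(50, 10, 1)
print(dv_lower_bound(out_joint, out_reference))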
Code example #3
def evaluate_dist(x, support):
    y_vec = np.arange(x - 7.5, x + 7.5, 0.05)
    y_givenx = np.zeros(len(y_vec))
    for k, y in enumerate(y_vec):
        batch_size = 1
        # Creating test data - for x,y respectively
        data_xy = {
            'features': (torch.ones(50, batch_size, args.ndim) * x),
            'labels': (torch.ones(50, batch_size, args.ndim) * y)
        }
        test_data_f2 = batchify_f2(data_xy, batch_size, args)
        # setting uniform distribution for y~
        test_randata_f2 = batchify_f2(data_xy,
                                      batch_size,
                                      args,
                                      uniformly=True)

        # Turn on evaluation mode which disables dropout.
        model_f2.eval()
        hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)
        dist_vector = torch.FloatTensor()

        for i in range(0, test_data_f2.size(0) - 1, args.bptt):
            data_f2 = get_batch_dine(test_data_f2, i, args, evaluation=True)
            randata_f2 = get_batch_dine(test_randata_f2,
                                        i,
                                        args,
                                        evaluation=True)
            # forward
            out_f2, out_reused_f2, hidden_f2 = parallel_model_f2(
                data_f2, randata_f2, hidden_f2)
            # distribution calculation
            if i:
                dist = torch.exp(out_f2) * (1 / support)
                dist_vector = torch.cat((dist_vector, dist), 0)
            else:
                dist_vector = torch.exp(out_f2) * (1 / support)

            # hidden repackage
            hidden_f2 = repackage_hidden(hidden_f2)

        y_givenx[k] = torch.mean(dist_vector).detach().cpu().numpy()

    y_givenx = y_givenx / np.sum(y_givenx)
    return y_vec, y_givenx
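
A possible usage sketch for evaluate_dist (the x value is made up, and support=15.0 is an assumption matching the 15-unit-wide y grid built inside the function, so that 1/support acts as the uniform reference density):

import matplotlib.pyplot as plt

y_vec, y_givenx = evaluate_dist(x=0.0, support=15.0)
plt.plot(y_vec, y_givenx)
plt.xlabel('y')
plt.ylabel('estimated p(y | x = 0)')
plt.show()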
Code example #4
File: main_dine.py Project: omerlux/DINE
def evaluate(data_source_f1, data_source_f2, batch_size=10):
    with open(os.path.join(args.save, 'validvars'), 'rb') as f:
        # Load the running sums of outputs and exponentiated outputs so the validation estimate accumulates from the start of evaluation
        [total_f1, total_f1_exps, total_f2, total_f2_exps,
         total_samples] = pickle.load(f)

    # Turn on evaluation mode which disables dropout.
    model_f1.eval()
    model_f2.eval()
    hidden_f1 = model_f1.init_hidden(batch_size, model_f1.ncell)
    hidden_f2 = model_f2.init_hidden(batch_size, model_f2.ncell)

    # Adjusting parameters for uniformly taken samples
    model_f1.rand_num = model_f2.rand_num = 0  # args.ndim * 1
    model_f1.sup_min = model_f2.sup_min = torch.min(data_source_f1)
    model_f1.sup_max = model_f2.sup_max = torch.max(data_source_f1)

    for i in range(0, data_source_f1.size(0) - 1, args.bptt):
        data_f1 = get_batch_dine(data_source_f1, i, args, evaluation=True)
        data_f2 = get_batch_dine(data_source_f2, i, args, evaluation=True)
        # forward
        with torch.no_grad():  # disable gradient tracking to avoid OOM during evaluation
            out_f1, out_reused_f1, hidden_f1 = parallel_model_f1(
                data_f1, hidden_f1)
            out_f2, out_reused_f2, hidden_f2 = parallel_model_f2(
                data_f2, hidden_f2)

        # loss aggregation
        total_f1 += torch.sum(out_f1)
        total_f1_exps += torch.sum(torch.exp(out_reused_f1))
        total_f2 += torch.sum(out_f2)
        total_f2_exps += torch.sum(torch.exp(out_reused_f2))
        total_samples += torch.numel(out_f1)

        # hidden repackage
        hidden_f1 = repackage_hidden(hidden_f1)
        hidden_f2 = repackage_hidden(hidden_f2)

    return [total_f1, total_f1_exps, total_f2, total_f2_exps, total_samples]
Code example #5
File: main_validloop.py Project: omerlux/DINE
def train():
    assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'

    # Creating train data
    _train_tmp_data = data.new_process(args, train_length)
    train_data_f1 = batchify_f1(_train_tmp_data, args.batch_size, args)
    train_data_f2 = batchify_f2(_train_tmp_data, args.batch_size, args)

    # Turn on training mode which enables dropout.
    total_loss_fi = [0, 0]
    start_time = time.time()
    hidden_f1 = [
        model_f1.init_hidden(args.small_batch_size, args.ncell)
        for _ in range(args.batch_size // args.small_batch_size)
    ]
    hidden_f2 = [
        model_f2.init_hidden(args.small_batch_size, args.ncell)
        for _ in range(args.batch_size // args.small_batch_size)
    ]
    # Shuffling data for y~ (ONLY y!)
    train_randata_f1 = batchify_f1(_train_tmp_data,
                                   args.batch_size,
                                   args,
                                   uniformly=True)
    train_randata_f2 = batchify_f2(_train_tmp_data,
                                   args.batch_size,
                                   args,
                                   uniformly=True)

    batch, i = 0, 0
    while i < train_data_f1.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        seq_len = min(seq_len, args.bptt + args.max_seq_len_delta)
        # Adjusting learning rate to variable length sequence
        lr_tmp_f1 = optimizer_f1.param_groups[0]['lr']
        optimizer_f1.param_groups[0]['lr'] = lr_tmp_f1 * seq_len / args.bptt
        lr_tmp_f2 = optimizer_f2.param_groups[0]['lr']
        optimizer_f2.param_groups[0]['lr'] = lr_tmp_f2 * seq_len / args.bptt

        # Activating train mode
        model_f1.train()
        model_f2.train()

        data_f1 = get_batch_dine(train_data_f1, i, args, seq_len=seq_len)
        data_f2 = get_batch_dine(train_data_f2, i, args, seq_len=seq_len)
        randata_f1 = get_batch_dine(train_randata_f1, i, args, seq_len=seq_len)
        randata_f2 = get_batch_dine(train_randata_f2, i, args, seq_len=seq_len)

        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

        start, end, s_id = 0, args.small_batch_size, 0
        while start < args.batch_size:
            cur_data_f1 = data_f1[:, start:end]
            cur_data_f2 = data_f2[:, start:end]
            cur_randata_f1 = randata_f1[:, start:end]
            cur_randata_f2 = randata_f2[:, start:end]

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            hidden_f1[s_id] = repackage_hidden(hidden_f1[s_id])
            hidden_f2[s_id] = repackage_hidden(hidden_f2[s_id])

            out_f1, out_reused_f1, hidden_f1[s_id] = parallel_model_f1(
                cur_data_f1, cur_randata_f1, hidden_f1[s_id])
            out_f2, out_reused_f2, hidden_f2[s_id] = parallel_model_f2(
                cur_data_f2, cur_randata_f2, hidden_f2[s_id])

            raw_loss_f1 = torch.mean(out_f1) - torch.log(
                torch.mean(torch.exp(out_reused_f1)))
            raw_loss_f2 = torch.mean(out_f2) - torch.log(
                torch.mean(torch.exp(out_reused_f2)))
            loss = [raw_loss_f1, raw_loss_f2]

            loss[0] *= args.small_batch_size / args.batch_size
            loss[1] *= args.small_batch_size / args.batch_size
            total_loss_fi[
                0] += raw_loss_f1.data * args.small_batch_size / args.batch_size
            total_loss_fi[
                1] += raw_loss_f2.data * args.small_batch_size / args.batch_size
            (-loss[0]).backward()  # for gradient ascent we use -loss
            (-loss[1]).backward()

            # optimizer grad check... note: comment this out once the 'nan' problem is solved
            for j, p in enumerate(optimizer_f1.param_groups[0]['params']):
                ans = p.grad.data[torch.isnan(p.grad.data)]
                if ans.size()[0] > 0:
                    logging.info(
                        "some nan exists in optimizer_f1 parameter #{} !".
                        format(j))
                    raise KeyboardInterrupt
            for j, p in enumerate(optimizer_f2.param_groups[0]['params']):
                ans = p.grad.data[torch.isnan(p.grad.data)]
                if ans.size()[0] > 0:
                    logging.info(
                        "some nan exists in optimizer_f2 parameter #{} !".
                        format(j))
                    raise KeyboardInterrupt

            s_id += 1
            start = end
            end = start + args.small_batch_size

            gc.collect()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model_f1.parameters(), args.clip)
        torch.nn.utils.clip_grad_norm_(model_f2.parameters(), args.clip)

        optimizer_f1.step()
        optimizer_f2.step()

        # Restore the learning rate to its pre-scaling value
        optimizer_f1.param_groups[0]['lr'] = lr_tmp_f1
        optimizer_f2.param_groups[0]['lr'] = lr_tmp_f2

        if batch < 1:
            log_interval = np.ceil((len(train_data_f1) // args.bptt) / 5)
        if batch % log_interval == 0 and batch > 0:
            curr_loss_fi = [
                total_loss_fi[0].item() / args.log_interval,
                total_loss_fi[1].item() / args.log_interval
            ]
            elapsed = time.time() - start_time
            logging.info(
                '| epoch {:3d} | {:3d}/{:3d} batches | lr1 {:f} | lr2 {:f} | ms/batch {:5.4f} | '
                'loss F1 {:6.4f} | loss F2 {:6.4f} | loss {:6.4f}'.format(
                    epoch, batch,
                    len(train_data_f1) // args.bptt,
                    optimizer_f1.param_groups[0]['lr'],
                    optimizer_f2.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, curr_loss_fi[0],
                    curr_loss_fi[1], curr_loss_fi[1] - curr_loss_fi[0]))
            total_loss_fi = [0, 0]
            start_time = time.time()
        ###
        batch += 1
        i += seq_len
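
repackage_hidden appears in every snippet above but is not defined here. A common implementation (the pattern from the PyTorch word-language-model example, assumed rather than taken from this project) simply detaches the hidden state so backpropagation stops at the batch boundary, which is exactly what the comment in the training loop describes:

import torch

def repackage_hidden(h):
    # Detach hidden states from their history so gradients do not
    # propagate past the current batch.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)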