Example #1
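
The snippets in this listing are excerpts from larger training scripts, so their module-level imports are not shown. A plausible header for this example and the next, matching the older Variable-based PyTorch API used throughout, would be (an assumption, not part of the original):

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
from torch.autograd import Variable
from tensorboardX import SummaryWriter  # newer code would use torch.utils.tensorboard

# args, loader_kwargs, BarleyBatches, ImageBags, add_results and the Attention /
# GatedAttention models come from the enclosing project.
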
train_loader = data_utils.DataLoader(BarleyBatches(train=True),
                                     batch_size=1,
                                     shuffle=True,
                                     **loader_kwargs)

test_loader = data_utils.DataLoader(BarleyBatches(train=False),
                                    batch_size=1,
                                    shuffle=False,
                                    **loader_kwargs)

print('Init Model')
if args.model == 'attention':
    model = Attention()
elif args.model == 'gated_attention':
    model = GatedAttention()
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       betas=(0.9, 0.999),
                       weight_decay=args.reg)
writer = SummaryWriter()


def train(epoch):
    model.train()
    train_loss = 0.
    train_error = 0.
    y_hat = []
    y = []
    for batch_idx, (data, label) in enumerate(train_loader):

Example #2

def main(reader, params):
    #from rcnn_attention import evaluator
    k_shot = params.k
    num_negative_bags = params.neg
    total_bags = k_shot + num_negative_bags
    result_lists = {}
    input_dim = 256 if params.dataset == 'omniglot' else 640

    for i, tensor_data in enumerate(reader.get_data()):
        if (i + 1) % 100 == 0:
            print('Evaluating problem number %d/%d' % (i + 1, params.eval_num))
        [feas, fea_boxes, fea_target_classes, fea_classes, imgs,
         target_class] = tensor_data[0:6]
        boxes_list = tensor_data[6:6 + total_bags]
        class_list = tensor_data[6 + total_bags:]
        bags = np.squeeze(feas)
        bag_labels = np.max(fea_target_classes, axis=1)
        input_labels = fea_target_classes.astype(np.int64)
        train_loader = data_utils.DataLoader(ImageBags(bags=bags,
                                                       labels=input_labels),
                                             batch_size=1,
                                             shuffle=True,
                                             **loader_kwargs)
        test_loader = data_utils.DataLoader(ImageBags(bags=bags,
                                                      labels=input_labels),
                                            batch_size=1,
                                            shuffle=False,
                                            **loader_kwargs)
        model = Attention(input_dim=input_dim)
        if params.cuda:
            model.cuda()
        optimizer = optim.Adam(model.parameters(),
                               lr=params.lr,
                               betas=(0.9, 0.999),
                               weight_decay=params.reg)

        def train(epoch):
            model.train()
            train_loss = 0.
            train_error = 0.
            for batch_idx, (data, label) in enumerate(train_loader):
                bag_label = label[0]
                if params.cuda:
                    data, bag_label = data.cuda(), bag_label.cuda()
                data, bag_label = Variable(data), Variable(bag_label)

                # reset gradients
                optimizer.zero_grad()
                # calculate loss and metrics
                loss, _ = model.calculate_objective(data, bag_label)
                train_loss += loss.data[0]
                #error, _ = model.calculate_classification_error(data, bag_label)
                #train_error += error
                # backward pass
                loss.backward()
                # step
                optimizer.step()

            train_loss /= len(train_loader)
            #print('epoch: {}, loss: {}'.format(epoch, train_loss))
            #train_error /= len(train_loader)

        def test():
            model.eval()
            test_loss = 0.
            test_error = 0.
            num_success = 0
            scores = np.zeros_like(fea_classes[:params.k])
            for batch_idx, (data, label) in enumerate(test_loader):
                bag_label = label[0]
                instance_labels = label[1]
                if params.cuda:
                    data, bag_label = data.cuda(), bag_label.cuda()
                data, bag_label = Variable(data), Variable(bag_label)
                loss, attention_weights = model.calculate_objective(
                    data, bag_label)
                test_loss += loss.data[0]
                #error, predicted_label = model.calculate_classification_error(data, bag_label)
                #test_error += error
                if batch_idx < params.k:
                    scores[batch_idx] = attention_weights.cpu().data.numpy()[0]
                    #argmax_pred = np.argmax(attention_weights.cpu().data.numpy()[0])
                    #val = instance_labels.numpy()[0].tolist()[argmax_pred]
                    #num_success += val
                    #print('batch idx: {}, val: {}'.format(batch_idx, val))
            #print('scores: ', scores)
            res = {
                'boxes': fea_boxes[:params.k],
                'classes': np.ones_like(fea_classes[:params.k]),
                'scores': scores,
                'class_agnostic': True
            }
            return res

        gt = {}
        gt['boxes'] = boxes_list[:params.k]
        gt['classes'] = class_list[:params.k]
        gt['target_class'] = target_class
        for epoch in range(1, args.epochs + 1):
            train(epoch)
        res = test()
        result_dict = {'groundtruth': gt, 'atnmil': res}
        from rcnn_attention import evaluator
        evaluator._postprocess_result_dict(result_dict)
        result_dict.pop('groundtruth')
        add_results(result_dict, result_lists)
        if i + 1 == params.eval_num:
            break
    metrics = {}
    from rcnn_attention import eval_util
    for method, result_list in result_lists.items():
        m = eval_util.evaluate_coloc_results(result_list, None)
        metrics[method] = m
    for k, v in metrics.items():
        print('{}: {}'.format(k, v))
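
For reference, a minimal, hypothetical driver for main() above; the reader object and every parameter value below are assumptions inferred from how main() uses params, not part of the original example:

import torch
from types import SimpleNamespace

# Hypothetical evaluation settings; the real values come from the project's CLI.
params = SimpleNamespace(k=5, neg=5, dataset='omniglot', lr=5e-4, reg=1e-4,
                         cuda=torch.cuda.is_available(), eval_num=100)
# `reader` must expose get_data(), yielding the tensor_data tuples unpacked in
# the loop above; loader_kwargs and the epoch count (args.epochs) come from the
# enclosing script.
# main(reader, params)
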
Example #3
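
As before, the module-level imports are not part of the snippet. Judging from the names used below, the file presumably begins with something like (project-specific helpers excluded):

import glob
import pickle
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# load_params, load_dataset, pair_data_generator, list_shuffle, pad_batch_list,
# hinge_loss, evaluate and the Attention model are project-specific and assumed
# to be importable from the surrounding package.
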
def train(config_path, resume=True):

    # Load the parameters
    param_dict, rep_param_dict = load_params(config_path)

    # use cuda flag
    use_cuda = True
    """
    the tranining directory
    """
    # load data
    TRAIN_DIR01 = "{}/MQ2007/S1/".format(param_dict["data_base_path"])
    TRAIN_DIR02 = "{}/MQ2007/S2/".format(param_dict["data_base_path"])
    TRAIN_DIR03 = "{}/MQ2007/S3/".format(param_dict["data_base_path"])
    TRAIN_DIR04 = "{}/MQ2007/S4/".format(param_dict["data_base_path"])
    TRAIN_DIR05 = "{}/MQ2007/S5/".format(param_dict["data_base_path"])

    TEST_DIR01 = '{}/MQ2007/S1/'.format(param_dict["data_base_path"])
    TEST_DIR02 = '{}/MQ2007/S2/'.format(param_dict["data_base_path"])
    TEST_DIR03 = '{}/MQ2007/S3/'.format(param_dict["data_base_path"])
    TEST_DIR04 = '{}/MQ2007/S4/'.format(param_dict["data_base_path"])
    TEST_DIR05 = '{}/MQ2007/S5/'.format(param_dict["data_base_path"])

    train_files01 = glob.glob("{}/data0.pkl".format(TRAIN_DIR01))
    train_files02 = glob.glob("{}/data0.pkl".format(TRAIN_DIR02))
    train_files03 = glob.glob("{}/data0.pkl".format(TRAIN_DIR03))
    train_files04 = glob.glob("{}/data0.pkl".format(TRAIN_DIR04))
    train_files05 = glob.glob("{}/data0.pkl".format(TRAIN_DIR05))

    test_files01 = glob.glob("{}/testdata0.pkl".format(TEST_DIR01))
    test_files02 = glob.glob("{}/testdata0.pkl".format(TEST_DIR02))
    test_files03 = glob.glob("{}/testdata0.pkl".format(TEST_DIR03))
    test_files04 = glob.glob("{}/testdata0.pkl".format(TEST_DIR04))
    test_files05 = glob.glob("{}/testdata0.pkl".format(TEST_DIR05))

    fold = param_dict["fold"]
    model_base_path = param_dict['model_base_path']
    model_name_str = param_dict['model_name_str']
    q_len = param_dict["q_len"]
    d_len = param_dict["d_len"]

    if fold == 1:
        train_files = train_files01 + train_files02 + train_files03
        test_files = test_files04[0]  # glob returns a one-element list; take the path string
        rel_path = '{}/{}/tmp/test/S4.qrels'.format(model_base_path,
                                                    model_name_str)
    elif fold == 2:
        train_files = train_files02 + train_files03 + train_files04
        test_files = test_files05[0]
        rel_path = '{}/{}/tmp/test/S5.qrels'.format(model_base_path,
                                                    model_name_str)
    elif fold == 3:
        train_files = train_files03 + train_files04 + train_files05
        test_files = test_files01[0]
        rel_path = '{}/{}/tmp/test/S1.qrels'.format(model_base_path,
                                                    model_name_str)
    elif fold == 4:
        train_files = train_files04 + train_files05 + train_files01
        test_files = test_files02[0]
        rel_path = '{}/{}/tmp/test/S2.qrels'.format(model_base_path,
                                                    model_name_str)
    elif fold == 5:
        train_files = train_files05 + train_files01 + train_files02
        test_files = test_files03[0]
        rel_path = '{}/{}/tmp/test/S3.qrels'.format(model_base_path,
                                                    model_name_str)
    else:
        raise ValueError("wrong fold num {}".format(fold))
    """
    Build the model
    """
    emb_size = param_dict['emb_size']
    num_heads = param_dict['num_heads']
    kernel_size = rep_param_dict['kernel_size']
    filt_size = rep_param_dict['filt_size']
    vocab_size = param_dict['vocab_size']
    output_dim = rep_param_dict['output_dim']
    hidden_size = param_dict['hidden_size']
    batch_size = param_dict['batch_size']
    preemb = param_dict['preemb']
    emb_path = param_dict['emb_path']
    hinge_margin = param_dict['hinge_margin']

    model = Attention(emb_size=emb_size,
                      query_length=q_len,
                      doc_length=d_len,
                      num_heads=num_heads,
                      kernel_size=kernel_size,
                      filter_size=filt_size,
                      vocab_size=vocab_size,
                      dropout=0.0,
                      qrep_dim=output_dim,
                      hidden_size=hidden_size,
                      batch_size=batch_size,
                      preemb=preemb,
                      emb_path=emb_path)

    if use_cuda:
        model.cuda()
    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=param_dict['learning_rate'],
                           betas=(param_dict['beta1'], param_dict['beta2']),
                           weight_decay=param_dict['alpha'])
    # loss func
    loss = nn.MarginRankingLoss(margin=hinge_margin, size_average=True)
    # experiment
    print("Experiment")

    if not resume:
        f_log = open(
            '{}/{}/logs/training_log.txt'.format(model_base_path,
                                                 model_name_str), 'w+', 1)
        valid_log = open(
            '{}/{}/logs/valid_log.txt'.format(model_base_path, model_name_str),
            'w+', 1)
    else:
        f_log = open(
            '{}/{}/logs/training_log.txt'.format(model_base_path,
                                                 model_name_str), 'a+', 1)
        valid_log = open(
            '{}/{}/logs/valid_log.txt'.format(model_base_path, model_name_str),
            'a+', 1)

    # model_file
    model_file = '{}/{}/saves/model_file'.format(model_base_path,
                                                 model_name_str)
    """
    TRAINING
    """

    # define the parameters
    n_epoch = param_dict['n_epoch']
    # init best validation MAP value
    best_MAP = 0.0
    best_NDCG1 = 0.0
    batch_count_tr = 0
    # restore saved parameter if resume_training is true
    if resume:
        model_file = '{}/{}/saves/model_file'.format(model_base_path,
                                                     model_name_str)
        model.load_state_dict(torch.load(model_file))
        with open(
                '{}/{}/saves/best_MAP.pkl'.format(model_base_path,
                                                  model_name_str),
                'rb') as f_MAP:
            best_MAP = pickle.load(f_MAP)
        print("loaded model, and resume training now")

    for epoch in range(1, n_epoch + 1):
        '''load_data'''
        for f in train_files:
            data = load_dataset(f)
            print("loaded {}".format(f))
            '''prepare_data'''
            [Q, D_pos, D_neg, L] = pair_data_generator(data, q_len)
            valid_data = load_dataset(test_files)
            ''' shuffle data'''
            train_data = list_shuffle(Q, D_pos, D_neg, L)
            '''training func'''

            num_batch = len(train_data[0]) // batch_size
            for batch_count in range(num_batch):
                Q = train_data[0][batch_size * batch_count:batch_size *
                                  (batch_count + 1)]
                D_pos = train_data[1][batch_size * batch_count:batch_size *
                                      (batch_count + 1)]
                D_neg = train_data[2][batch_size * batch_count:batch_size *
                                      (batch_count + 1)]
                L = train_data[3][batch_size * batch_count:batch_size *
                                  (batch_count + 1)]
                if use_cuda:
                    Q = Variable(torch.LongTensor(
                        pad_batch_list(Q, max_len=q_len, padding_id=0)),
                                 requires_grad=False).cuda()
                    D_pos = Variable(torch.LongTensor(
                        pad_batch_list(D_pos, max_len=d_len, padding_id=0)),
                                     requires_grad=False).cuda()
                    D_neg = Variable(torch.LongTensor(
                        pad_batch_list(D_neg, max_len=d_len, padding_id=0)),
                                     requires_grad=False).cuda()
                    L = Variable(torch.FloatTensor(L),
                                 requires_grad=False).cuda()
                else:
                    Q = Variable(torch.LongTensor(
                        pad_batch_list(Q, max_len=q_len, padding_id=0)),
                                 requires_grad=False)
                    D_pos = Variable(torch.LongTensor(
                        pad_batch_list(D_pos, max_len=d_len, padding_id=0)),
                                     requires_grad=False)
                    D_neg = Variable(torch.LongTensor(
                        pad_batch_list(D_neg, max_len=d_len, padding_id=0)),
                                     requires_grad=False)
                    L = Variable(torch.FloatTensor(L), requires_grad=False)

                # run on this batch
                optimizer.zero_grad()
                t1 = time.time()

                q_mask, d_pos_mask, d_neg_mask = model.generate_mask(
                    Q, D_pos, D_neg)
                """
                need to do the modification i the model.py
                """
                S_pos, S_neg = model(Q, D_pos, D_neg, q_mask, d_pos_mask,
                                     d_neg_mask)
                Loss = hinge_loss(S_pos, S_neg, 1.0)
                Loss.backward()
                optimizer.step()
                t2 = time.time()
                batch_count_tr += 1
                print("epoch {} batch {} training cost: {} using {}s" \
                .format(epoch, batch_count+1, Loss.data[0], t2-t1))
                f_log.write("epoch {} batch {} training cost: {}, using {}s".
                            format(epoch, batch_count + 1, Loss.data[0], t2 -
                                   t1) + '\n')
                """
                evaluate part
                """
                if batch_count_tr % 20 == 0:
                    if valid_data is not None:
                        MAP, NDCGs = evaluate(config_path,
                                              model,
                                              valid_data,
                                              rel_path,
                                              mode="valid")
                        print(MAP, NDCGs)
                        valid_log.write(
                            "epoch {}, batch {}, MAP: {}, NDCGs: {} {} {} {}".
                            format(epoch + 1, batch_count + 1, MAP,
                                   NDCGs[1][0], NDCGs[1][1], NDCGs[1][2],
                                   NDCGs[1][3]))
                        if MAP > best_MAP:  # save this best model
                            best_MAP = MAP
                            with open(
                                    '{}/{}/saves/best_MAP.pkl'.format(
                                        model_base_path, model_name_str),
                                    'wb') as f_MAP:
                                pickle.dump(best_MAP, f_MAP)
                            # save model params after several epoch
                            model_file = '{}/{}/saves/model_file'.format(
                                model_base_path, model_name_str)
                            torch.save(model.state_dict(), model_file)
                            print("successfully saved model to the path {}".
                                  format(model_file))

                        valid_log.write("{} {} {} {}".format(
                            NDCGs[1][0], NDCGs[1][1], NDCGs[1][2],
                            NDCGs[1][3]))
                        valid_log.write(" MAP: {}".format(MAP))
                        valid_log.write('\n')
    f_log.close()
    valid_log.close()
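
hinge_loss above is a project helper that is not shown here (the nn.MarginRankingLoss created earlier is never actually used). A minimal sketch consistent with the call hinge_loss(S_pos, S_neg, 1.0), assuming a standard pairwise ranking hinge, would be:

import torch

def hinge_loss(score_pos, score_neg, margin):
    # pairwise ranking hinge: mean of max(0, margin - (s_pos - s_neg))
    return torch.clamp(margin - score_pos + score_neg, min=0.0).mean()
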
Example #4
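
This method sits inside a trainer class whose module-level imports are again not shown. The names used below suggest roughly the following header (the timer import is an assumption):

import os
import pickle
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from timeit import default_timer as timer  # assumed source of timer()

# Generator, Discriminator, Attention, RankPredictor, eval.Evaluator and helpers
# such as get_hyperparams, to_cuda, construct_input, get_knn_embedding and
# convert_to_embeddings come from the surrounding project.
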
    def train(self, src_emb, tgt_emb):
        params = self.params
        suffix_str = params.suffix_str
        # Load data
        if not os.path.exists(params.data_dir):
            raise "Data path doesn't exists: %s" % params.data_dir

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        evaluator = eval.Evaluator(params,
                                   src_emb.weight.data,
                                   tgt_emb.weight.data,
                                   use_cuda=True)

        if params.context > 0:
            try:
                knn_list = pickle.load(
                    open('full_knn_list_' + suffix_str + '.pkl', 'rb'))
            except FileNotFoundError:
                knn_list = get_knn_embedding(params,
                                             src_emb,
                                             suffix_str,
                                             context=params.context,
                                             method='csls',
                                             use_cuda=True)
            self.knn_emb = convert_to_embeddings(knn_list, use_cuda=True)

        for _ in range(params.num_random_seeds):

            # Create models
            g = Generator(input_size=params.g_input_size,
                          hidden_size=params.g_hidden_size,
                          output_size=params.g_output_size,
                          hyperparams=get_hyperparams(params, disc=False))
            d = Discriminator(input_size=params.d_input_size,
                              hidden_size=params.d_hidden_size,
                              output_size=params.d_output_size,
                              hyperparams=get_hyperparams(params, disc=True))
            a = Attention(atype=params.atype,
                          input_size=2 * params.g_input_size,
                          hidden_size=params.a_hidden_size)
            r_p = RankPredictor(
                input_size=params.g_output_size,
                output_size=int(
                    np.floor(np.log(params.most_frequent_sampling_size)) + 1),
                hidden_size=params.d_hidden_size // 4,
                leaky_slope=params.leaky_slope)

            if params.initialize_prev_best == 1 and params.context in [0, 2]:
                prev_best_model_file_path = os.path.join(
                    params.model_dir, params.prev_best_model_fname)
                g.load_state_dict(
                    torch.load(prev_best_model_file_path, map_location='cpu'))
                print(g.map1.weight.data)

            if params.seed > 0:
                seed = params.seed
            else:
                seed = random.randint(0, 1000)
            # init_xavier(g)
            # init_xavier(d)
            self.initialize_exp(seed)

            # Define loss function and optimizers
            loss_fn = torch.nn.BCELoss()
            r_p_loss_fn = torch.nn.CrossEntropyLoss()
            d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
            g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)
            # optimizer for the rank predictor's own parameters
            r_p_optimizer = optim.SGD(r_p.parameters(),
                                      lr=params.g_learning_rate)

            if params.atype in ['mlp', 'bilinear']:
                a_optimizer = optim.SGD(a.parameters(),
                                        lr=params.g_learning_rate)

            if torch.cuda.is_available():
                # Move the network and the optimizer to the GPU
                g = g.cuda()
                d = d.cuda()
                a = a.cuda()
                r_p = r_p.cuda()
                loss_fn = loss_fn.cuda()
                r_p_loss_fn = r_p_loss_fn.cuda()

                # Regularization loss
                reg_loss = 0
                for i, p in enumerate(g.parameters()):
                    if i > 0:
                        break
                    pred = p.transpose(0, 1)[300:, :]
                    reg_loss += pred.norm(2)
                factor = 1e-2

                reg_loss = reg_loss.cuda()
            # true_dict = get_true_dict(params.data_dir)
            d_acc_epochs = []
            g_loss_epochs = []

            # logs for plotting later
            log_file = open(
                "log_{}_{}_{}.txt".format(self.params.src_lang,
                                          self.params.tgt_lang, seed),
                "w")  # Being overwritten in every loop, not really required
            log_file.write("epoch, dis_loss, dis_acc, g_loss, acc, acc_new\n")

            try:
                for epoch in range(params.num_epochs):

                    d_losses = []
                    g_losses = []
                    rank_losses = []
                    hit = 0
                    total = 0
                    start_time = timer()

                    for mini_batch in range(
                            0,
                            params.iters_in_epoch // params.mini_batch_size):
                        # W_orig = g.map1.weight.data
                        # print(W_orig)

                        for d_index in range(params.d_steps):
                            d_optimizer.zero_grad()  # Reset the gradients
                            d.train()
                            input, output = self.get_batch_data_fast(
                                en, it, g, a, detach=True)
                            pred = d(input)
                            d_loss = loss_fn(pred, output)
                            d_loss.backward(
                            )  # compute/store gradients, but don't change params
                            d_losses.append(d_loss.data.cpu().numpy())
                            discriminator_decision = pred.data.cpu().numpy()
                            hit += np.sum(
                                discriminator_decision[:params.mini_batch_size]
                                >= 0.5)
                            hit += np.sum(
                                discriminator_decision[params.mini_batch_size:]
                                < 0.5)
                            d_optimizer.step(
                            )  # Only optimizes D's parameters; changes based on stored gradients from backward()

                            # Clip weights
                            _clip(d, params.clip_value)

                            sys.stdout.write(
                                "[%d/%d] :: Discriminator Loss: %f \r" %
                                (mini_batch, params.iters_in_epoch //
                                 params.mini_batch_size,
                                 np.asscalar(np.mean(d_losses))))
                            sys.stdout.flush()

                        total += 2 * params.mini_batch_size * params.d_steps

                        for g_index in range(params.g_steps):
                            # 2. Train G on D's response (but DO NOT train D on these labels)
                            g_optimizer.zero_grad()
                            d.eval()

                            if params.use_rank_predictor > 0:
                                input, output, true_ranks = self.get_batch_data_fast(
                                    en,
                                    it,
                                    g,
                                    a,
                                    detach=False,
                                    use_rank_predictor=True)
                            else:
                                input, output = self.get_batch_data_fast(
                                    en, it, g, a, detach=False)

                            pred = d(input)
                            g_loss = loss_fn(pred, 1 - output)

                            g_loss += factor * reg_loss

                            # retain the graph: reg_loss is reused on every generator
                            # step, and the rank-predictor loss reuses the same inputs
                            g_loss.backward(retain_graph=True)
                            g_optimizer.step()  # Only optimizes G's parameters
                            if params.atype in ['mlp', 'bilinear']:
                                a_optimizer.step()
                            g_losses.append(g_loss.data.cpu().numpy())

                            if params.use_rank_predictor > 0:
                                # First half of input are the transformed embeddings
                                fake_input = input[:len(input) // 2]
                                rank_predictions = r_p(fake_input)
                                rank_loss = r_p_loss_fn(
                                    rank_predictions, true_ranks)
                                rank_loss.backward()
                                r_p_optimizer.step()
                                rank_losses.append(
                                    rank_loss.data.cpu().numpy())

                            # Orthogonalize
                            if params.context == 1:
                                pass
                                # for i, p in enumerate(g.parameters()):
                                #     print("%d: " % i)
                                #     print(p.shape)
                                # W_orig = g.map1.weight.data
                                # print(W_orig)
                                # print(W_orig)
                                # W_top = W_orig[:300, :300]
                                # W_bottom = W_orig[300:, 300:]
                                # print(W_top)
                                # print(W_bottom)
                                # self.orthogonalize(g.map2.weight.data)
                            else:
                                self.orthogonalize(g.map1.weight.data)

                            if params.use_rank_predictor > 0:
                                sys.stdout.write(
                                    "[%d/%d] ::                                     Generator Loss: %f , Rank Loss: %f \r"
                                    % (mini_batch, params.iters_in_epoch //
                                       params.mini_batch_size,
                                       np.asscalar(np.mean(g_losses)),
                                       np.asscalar(np.mean(rank_losses))))
                            else:
                                sys.stdout.write(
                                    "[%d/%d] ::                                     Generator Loss: %f \r"
                                    % (mini_batch, params.iters_in_epoch //
                                       params.mini_batch_size,
                                       np.asscalar(np.mean(g_losses))))
                            sys.stdout.flush()

                    d_acc_epochs.append(hit / total)
                    g_loss_epochs.append(np.asscalar(np.mean(g_losses)))
                    print(
                        "Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f}, Generator Loss: {:.5f}, Time elapsed {:.2f} mins"
                        .format(epoch, np.asscalar(np.mean(d_losses)),
                                hit / total, np.asscalar(np.mean(g_losses)),
                                (timer() - start_time) / 60))

                    # lr decay
                    g_optim_state = g_optimizer.state_dict()
                    old_lr = g_optim_state['param_groups'][0]['lr']
                    g_optim_state['param_groups'][0]['lr'] = max(
                        old_lr * params.lr_decay, params.lr_min)
                    g_optimizer.load_state_dict(g_optim_state)
                    print("Changing the learning rate: {} -> {}".format(
                        old_lr, g_optim_state['param_groups'][0]['lr']))
                    d_optim_state = d_optimizer.state_dict()
                    d_optim_state['param_groups'][0]['lr'] = max(
                        d_optim_state['param_groups'][0]['lr'] *
                        params.lr_decay, params.lr_min)
                    d_optimizer.load_state_dict(d_optim_state)

                    if (epoch + 1) % params.print_every == 0:
                        # No need for discriminator weights
                        # torch.save(d.state_dict(), 'discriminator_weights_en_es_{}.t7'.format(epoch))
                        if params.context > 0:
                            indices = torch.arange(
                                params.top_frequent_words).type(
                                    torch.LongTensor)
                            indices = to_cuda(indices, use_cuda=True)
                            all_precisions = evaluator.get_all_precisions(
                                g(
                                    construct_input(self.knn_emb,
                                                    indices,
                                                    en,
                                                    a,
                                                    atype=params.atype,
                                                    context=params.context,
                                                    use_cuda=True)).data)
                        else:
                            all_precisions = evaluator.get_all_precisions(
                                g(src_emb.weight).data)
                        #print(json.dumps(all_precisions))
                        p_1 = all_precisions['validation']['adv'][
                            'without-ref']['nn'][1]
                        p_1_new = all_precisions['validation-new']['adv'][
                            'without-ref']['nn'][1]
                        log_file.write(
                            "{},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f}\n".format(
                                epoch + 1,
                                np.asscalar(np.mean(d_losses)), hit / total,
                                np.asscalar(np.mean(g_losses)), p_1, p_1_new))
                        #log_file.write(str(all_precisions) + "\n")
                        # Saving generator weights

                        torch.save(
                            g.state_dict(), 'generator_weights_' + suffix_str +
                            '_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(
                                seed, epoch, params.g_learning_rate, p_1))
                        if params.atype in ['mlp', 'bilinear']:
                            torch.save(
                                a.state_dict(),
                                'generator_weights_' + suffix_str +
                                '_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(
                                    seed, epoch, params.a_learning_rate, p_1))

                # Save the plot for discriminator accuracy and generator loss
                fig = plt.figure()
                plt.plot(range(0, params.num_epochs),
                         d_acc_epochs,
                         color='b',
                         label='discriminator')
                plt.plot(range(0, params.num_epochs),
                         g_loss_epochs,
                         color='r',
                         label='generator')
                plt.ylabel('accuracy/loss')
                plt.xlabel('epochs')
                plt.legend()
                fig.savefig('d_g.png')

            except KeyboardInterrupt:
                print("Interrupted.. saving model !!!")
                torch.save(g.state_dict(), 'g_model_interrupt.t7')
                torch.save(d.state_dict(), 'd_model_interrupt.t7')
                if params.atype in ['mlp', 'bilinear']:
                    torch.save(a.state_dict(), 'a_model_interrupt.t7')
                log_file.close()
                exit()
            log_file.close()
        return g
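
self.orthogonalize(g.map1.weight.data) in the generator step keeps the learned mapping close to orthogonal. The method itself is not shown; a sketch of the usual iterative update from adversarial embedding-mapping work, written here as a plain function and offered only as an assumption about this codebase, is:

import torch

def orthogonalize(W, beta=0.01):
    # in-place update W <- (1 + beta) * W - beta * (W W^T) W
    W.copy_((1 + beta) * W - beta * W.mm(W.t()).mm(W))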