Example #1
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training with probability 0.5
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
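
The warm-start logic above (filter the pretrained state dict against the current model's keys, merge, then load) is a reusable pattern. A minimal standalone sketch of it follows; it assumes the checkpoint stores a plain state dict saved with torch.save(model.state_dict(), path), whereas the example above unpickles a whole model and then calls state_dict() on it.

import torch


def load_partial_state_dict(model, checkpoint_path):
    """Warm-start `model` from a checkpoint, keeping only matching keys."""
    model_dict = model.state_dict()
    pretrained_dict = torch.load(checkpoint_path, map_location='cpu')
    # 1. filter out keys that are missing from the model or differ in shape
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the merged state dict
    model.load_state_dict(model_dict)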
Example #2
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    #print(f"First dict: {pretrained_dict}")
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    #print(f"Second dict: {pretrained_dict}")
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    #print(f"model dict: {model_dict}")
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    #return
    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(size_average=False,
                                   ignore_index=dataset.dst_dict.pad(),
                                   reduce=True)
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training with probability 0.5
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    #print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)  # dataset.dst_dict.pad())
                train_trg_batch = sample['target']
                #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                #print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence,
                                     trg_sentence)  #, dataset.dst_dict.pad())
            #print(f"disc out: {disc_out.shape}, labels: {labels.shape}")
            #print(f"labels: {labels}")
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(torch.sigmoid(disc_out).round() ==
                            labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg,
            #                                                                            d_logging_meters['train_acc'].avg,
            #                                                                            i))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # train with half human-translation and half machine translation

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg,
                #                                                                     d_logging_meters['valid_acc'].avg, i))

        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
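
A side note on the optimizer construction used by all of these examples: eval("torch.optim." + args.g_optimizer) resolves the optimizer class by name, but getattr(torch.optim, name) performs the same lookup without evaluating an arbitrary string. A small sketch of that variant (the flag names below mirror the ones assumed above):

import torch


def build_optimizer(name, parameters, lr, **kwargs):
    """Resolve a torch.optim class by name, e.g. 'Adam' or 'SGD', and
    build it over the trainable parameters only."""
    optimizer_cls = getattr(torch.optim, name)
    trainable = [p for p in parameters if p.requires_grad]
    return optimizer_cls(trainable, lr=lr, **kwargs)


# Usage mirroring the examples above:
# g_optimizer = build_optimizer(args.g_optimizer, generator.parameters(),
#                               args.g_learning_rate)
# d_optimizer = build_optimizer(args.d_optimizer, discriminator.parameters(),
#                               args.d_learning_rate,
#                               momentum=args.momentum, nesterov=True)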
Example #3
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.3
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.3
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss
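    # The helper above is a standard GAN discriminator objective: it rewards
    # high scores on the teacher-forced (MLE) hidden states and low scores on
    # the free-running (PG) hidden states, i.e. -log D(tf) - log(1 - D(ar)),
    # with 1e-9 inside the logs to avoid log(0).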

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/professor2'):
        os.makedirs('checkpoints/professor2')
    checkpoints_path = 'checkpoints/professor2/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True)

    # fix discriminator_h word embedding (as Wu et al. do)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(lambda x: x.requires_grad,
                                                                 generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer_h = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_h.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    d_optimizer_s = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_s.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator_h.train()
        discriminator_s.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            # print("Policy Gradient Training")
            sys_out_batch_PG, p_PG, hidden_list_PG = generator('PG', epoch_i, sample)  # 64 X 50 X 6632

            out_batch_PG = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632

            _, prediction = out_batch_PG.topk(1)
            prediction = prediction.squeeze(1)  # 64*50 = 3200
            prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape)  # 64 X 50

            with torch.no_grad():
                reward = discriminator_s(sample['net_input']['src_tokens'], prediction)  # 64 X 1

            train_trg_batch_PG = sample['target']  # 64 x 50

            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda)
            sample_size_PG = sample['target'].size(0) if args.sentence_avg else sample['ntokens']  # 64
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(), sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            # print("MLE Training")
            sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator("MLE", epoch_i, sample)

            out_batch_MLE = sys_out_batch_MLE.contiguous().view(-1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632

            train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

            sample_size_MLE = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE, sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                # print(p.size())
                if p.requires_grad:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            num_update += 1


            #  part II: train the discriminator

            # discriminator_h
            if num_update % 5 == 0:

                d_MLE = discriminator_h(hidden_list_MLE)
                d_PG = discriminator_h(hidden_list_PG)
                d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
                logging.debug(f"D_h training loss {d_loss} at batch {i}")

                d_optimizer_h.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator_h.parameters(), args.clip_norm)
                d_optimizer_h.step()




                #discriminator_s
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input']['src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = torch.ones(sample['target'].size(0)).float()  # 64 length vector
                with torch.no_grad():
                    sys_out_batch, p, hidden_list = generator('MLE', epoch_i, sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                fake_labels = torch.zeros(sample['target'].size(0)).float()  # 64 length vector

                fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                fake_disc_out = discriminator_s(src_sentence, fake_sentence)  # 64 X 1
                true_disc_out = discriminator_s(src_sentence, true_sentence)

                fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)

                acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)

                d_loss = fake_d_loss + true_d_loss

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D_s training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")
                d_optimizer_s.zero_grad()
                d_loss.backward()
                d_optimizer_s.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator_h.eval()
        discriminator_s.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch_test, p_test, hidden_list_test = generator('test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(-1, sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test, sample_size_test)
                logging.debug(f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}")

                # # discriminator_h validation
                # bsz = sample['target'].size(0)
                # src_sentence = sample['net_input']['src_tokens']
                # # train with half human-translation and half machine translation
                # true_sentence = sample['target']
                # true_labels = torch.ones(sample['target'].size(0)).float()
                # with torch.no_grad():
                #     sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample)
                #
                # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 X 50) X 6632
                # _, prediction = out_batch.topk(1)
                # prediction = prediction.squeeze(1)  # 64 * 50 = 6632
                # fake_labels = torch.zeros(sample['target'].size(0)).float()
                # fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                # if use_cuda:
                #     fake_labels = fake_labels.cuda()
                # disc_out = discriminator_h(src_sentence, fake_sentence)
                # d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # d_logging_meters['valid_acc'].update(acc)
                # d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug(
                #     f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

        torch.save(generator,
                   open(checkpoints_path + f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                        'wb'), pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
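
All three GAN-style examples above train the generator with pg_criterion, a policy-gradient loss in which the discriminator's score on the sampled translation acts as a per-sentence reward weighting the negative log-likelihood of the target tokens. PGLoss itself is not shown in these listings; a minimal sketch of a loss with that shape, assuming log-probabilities of shape (batch, seq_len, vocab), targets of shape (batch, seq_len), a (batch, 1) reward, and a simplified signature (no use_cuda flag), could look like this:

import torch


class SimplePGLoss(torch.nn.Module):
    """Reward-weighted NLL over non-padding tokens (a sketch, not the
    repository's actual PGLoss)."""

    def __init__(self, ignore_index):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, log_probs, targets, reward):
        # log_probs: (batch, seq_len, vocab) log-probabilities from the generator
        # targets:   (batch, seq_len) reference token ids
        # reward:    (batch, 1) discriminator scores in [0, 1]
        nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        mask = targets.ne(self.ignore_index).float()
        weighted = nll * mask * reward  # reward broadcasts over time steps
        return weighted.sum() / mask.sum().clamp(min=1.0)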
Example #4
def main(options):

    use_cuda = (len(options.gpuid) >= 1)
    # if options.gpuid:
    #   cuda.set_device(options.gpuid[0])

    src_train, src_dev, src_test, src_vocab = torch.load(
        open(options.data_file + "." + options.src_lang, 'rb'))
    trg_train, trg_dev, trg_test, trg_vocab = torch.load(
        open(options.data_file + "." + options.trg_lang, 'rb'))

    batched_train_src, batched_train_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_train, options.batch_size, src_vocab.stoi["<blank>"])
    batched_train_trg, batched_train_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_train, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)
    batched_dev_src, batched_dev_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_dev, options.batch_size, src_vocab.stoi["<blank>"])
    batched_dev_trg, batched_dev_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_dev, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)

    print "preprocessing batched data..."
    processed_src = list()
    processed_trg = list()
    processed_src_mask = list()
    processed_trg_mask = list()
    for batch_i in range(len(batched_train_src)):
        if batched_train_src[batch_i].size(
                0) <= 35 and batched_train_trg[batch_i].size(0) <= 35:
            processed_src.append(batched_train_src[batch_i])
            processed_trg.append(batched_train_trg[batch_i])
            processed_src_mask.append(batched_train_src_mask[batch_i])
            processed_trg_mask.append(batched_train_trg_mask[batch_i])

    batched_train_src = processed_src
    batched_train_trg = processed_trg
    batched_train_src_mask = processed_src_mask
    batched_train_trg_mask = processed_trg_mask

    processed_src = list()
    processed_trg = list()
    processed_src_mask = list()
    processed_trg_mask = list()
    for batch_i in range(len(batched_dev_src)):
        if batched_dev_src[batch_i].size(
                0) <= 35 and batched_dev_trg[batch_i].size(0) <= 35:
            processed_src.append(batched_dev_src[batch_i])
            processed_trg.append(batched_dev_trg[batch_i])
            processed_src_mask.append(batched_dev_src_mask[batch_i])
            processed_trg_mask.append(batched_dev_trg_mask[batch_i])

    batched_dev_src = processed_src
    batched_dev_trg = processed_trg
    batched_dev_src_mask = processed_src_mask
    batched_dev_trg_mask = processed_trg_mask

    del processed_src, processed_trg, processed_trg_mask, processed_src_mask

    trg_vocab_size = len(trg_vocab)
    src_vocab_size = len(src_vocab)
    word_emb_size = 50
    hidden_size = 1024

    nmt = NMT(src_vocab_size,
              trg_vocab_size,
              word_emb_size,
              hidden_size,
              src_vocab,
              trg_vocab,
              attn_model="general",
              use_cuda=True)
    discriminator = Discriminator(src_vocab_size,
                                  trg_vocab_size,
                                  word_emb_size,
                                  src_vocab,
                                  trg_vocab,
                                  use_cuda=True)

    if use_cuda:
        #nmt = torch.nn.DataParallel(nmt,device_ids=options.gpuid).cuda()
        nmt.cuda()
        #discriminator = torch.nn.DataParallel(discriminator,device_ids=options.gpuid).cuda()
        discriminator.cuda()
    else:
        nmt.cpu()
        discriminator.cpu()

    criterion_g = torch.nn.NLLLoss().cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Configure optimization
    optimizer_g = eval("torch.optim." + options.optimizer)(
        nmt.parameters(), options.learning_rate)
    optimizer_d = eval("torch.optim." + options.optimizer)(
        discriminator.parameters(), options.learning_rate)

    # main training loop
    f1 = open("train_loss", "a")
    f2 = open("dev_loss", "a")
    last_dev_avg_loss = float("inf")
    for epoch_i in range(options.epochs):
        logging.info("At {0}-th epoch.".format(epoch_i))
        # utils.rand.srange yields a lazily shuffled range of batch indices

        train_loss_g = 0.0
        train_loss_d = 0.0
        train_loss_g_nll = 0.0
        train_loss_g_ce = 0.0
        train_loss_nll_batch_num = 0
        train_loss_ce_batch_num = 0
        for i, batch_i in enumerate(utils.rand.srange(len(batched_train_src))):
            # cap each epoch at 1500 training batches
            if i == 1500:
                break
            train_src_batch = Variable(batched_train_src[batch_i]
                                       )  # of size (src_seq_len, batch_size)
            train_trg_batch = Variable(batched_train_trg[batch_i]
                                       )  # of size (trg_seq_len, batch_size)
            train_src_mask = Variable(batched_train_src_mask[batch_i])
            train_trg_mask = Variable(batched_train_trg_mask[batch_i])
            if use_cuda:
                train_src_batch = train_src_batch.cuda()
                train_trg_batch = train_trg_batch.cuda()
                train_src_mask = train_src_mask.cuda()
                train_trg_mask = train_trg_mask.cuda()

            # train discriminator
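            # the generator output is detached so this update does not backpropagate into the generator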
            sys_out_batch = nmt(train_src_batch, train_trg_batch,
                                True).detach()
            _, predict_batch = sys_out_batch.topk(1)
            del _
            predict_batch = predict_batch.squeeze(2)
            real_dis_label_out = discriminator(train_src_batch,
                                               train_trg_batch, True)
            fake_dis_label_out = discriminator(train_src_batch, predict_batch,
                                               True)
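            # reference (real) pairs are labelled 1, generated (fake) pairs 0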
            optimizer_d.zero_grad()
            loss_d_real = criterion(
                real_dis_label_out,
                Variable(
                    torch.ones(options.batch_size *
                               len(options.gpuid)).long()).cuda())
            loss_d_real.backward()
            loss_d_fake = criterion(
                fake_dis_label_out,
                Variable(
                    torch.zeros(options.batch_size *
                                len(options.gpuid)).long()).cuda())
            #loss_d_fake.backward(retain_graph=True)
            loss_d_fake.backward()
            loss_d = loss_d_fake.data[0] + loss_d_real.data[0]
            del loss_d_fake, loss_d_real
            logging.debug("D loss at batch {0}: {1}".format(i, loss_d))
            f1.write("D train loss at batch {0}: {1}\n".format(i, loss_d))
            optimizer_d.step()

            if use_cuda:
                sys_out_batch = sys_out_batch.cuda()
                train_trg_batch = train_trg_batch.cuda()
            else:
                sys_out_batch = sys_out_batch.cpu()
                train_trg_batch = train_trg_batch.cpu()

            # train nmt
            sys_out_batch = nmt(train_src_batch, train_trg_batch, True)
            _, predict_batch = sys_out_batch.topk(1)
            predict_batch = predict_batch.squeeze(2)
            fake_dis_label_out = discriminator(train_src_batch, predict_batch,
                                               True)
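            # mix objectives: ~50% teacher-forced NLL against the reference,
            # ~50% adversarial cross-entropy that rewards fooling the discriminator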
            if random.random() > 0.5:
                train_trg_mask = train_trg_mask.view(-1)
                train_trg_batch = train_trg_batch.view(-1)
                train_trg_batch = train_trg_batch.masked_select(train_trg_mask)
                train_trg_mask = train_trg_mask.unsqueeze(1).expand(
                    len(train_trg_mask), trg_vocab_size)
                sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
                sys_out_batch = sys_out_batch.masked_select(
                    train_trg_mask).view(-1, trg_vocab_size)
                loss_g = criterion_g(sys_out_batch, train_trg_batch)
                train_loss_g_nll += loss_g.data[0]
                train_loss_nll_batch_num += 1
                f1.write("G train NLL loss at batch {0}: {1}\n".format(
                    i, loss_g.data[0]))
            else:
                loss_g = criterion(
                    fake_dis_label_out,
                    Variable(
                        torch.ones(options.batch_size *
                                   len(options.gpuid)).long()).cuda())
                train_loss_g_ce += loss_g.data[0]
                train_loss_ce_batch_num += 1
                f1.write("G train CE loss at batch {0}: {1}\n".format(
                    i, loss_g.data[0]))

            logging.debug("G loss at batch {0}: {1}".format(i, loss_g.data[0]))

            optimizer_g.zero_grad()
            loss_g.backward()

            # # gradient clipping
            torch.nn.utils.clip_grad_norm(nmt.parameters(), 5.0)
            optimizer_g.step()

            train_loss_d += loss_d
        train_avg_loss_g_nll = train_loss_g_nll / max(train_loss_nll_batch_num, 1)
        train_avg_loss_g_ce = train_loss_g_ce / max(train_loss_ce_batch_num, 1)
        train_avg_loss_d = train_loss_d / max(
            train_loss_nll_batch_num + train_loss_ce_batch_num, 1)
        logging.info(
            "G TRAIN Average NLL loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_g_nll, epoch_i))
        logging.info(
            "G TRAIN Average CE loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_g_ce, epoch_i))
        logging.info(
            "D TRAIN Average loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_d, epoch_i))

        # validation (currently disabled) -- this is a crude estimation because there might be some paddings at the end
        # dev_loss_g_nll = 0.0
        # dev_loss_g_ce = 0.0
        # dev_loss_d = 0.0

        # for batch_i in range(len(batched_dev_src)):
        #   dev_src_batch = Variable(batched_dev_src[batch_i], volatile=True)
        #   dev_trg_batch = Variable(batched_dev_trg[batch_i], volatile=True)
        #   dev_src_mask = Variable(batched_dev_src_mask[batch_i], volatile=True)
        #   dev_trg_mask = Variable(batched_dev_trg_mask[batch_i], volatile=True)
        #   if use_cuda:
        #     dev_src_batch = dev_src_batch.cuda()
        #     dev_trg_batch = dev_trg_batch.cuda()
        #     dev_src_mask = dev_src_mask.cuda()
        #     dev_trg_mask = dev_trg_mask.cuda()

        #   sys_out_batch = nmt(dev_src_batch, dev_trg_batch, False).detach()
        #   _,predict_batch = sys_out_batch.topk(1)
        #   predict_batch = predict_batch.squeeze(2)
        #   real_dis_label_out = discriminator(dev_src_batch, dev_trg_batch, True).detach()
        #   fake_dis_label_out = discriminator(dev_src_batch, predict_batch, True).detach()

        #   if use_cuda > 0:
        #     sys_out_batch = sys_out_batch.cuda()
        #     dev_trg_batch = dev_trg_batch.cuda()
        #   else:
        #     sys_out_batch = sys_out_batch.cpu()
        #     dev_trg_batch = dev_trg_batch.cpu()

        #   dev_trg_mask = dev_trg_mask.view(-1)
        #   dev_trg_batch = dev_trg_batch.view(-1)
        #   dev_trg_batch = dev_trg_batch.masked_select(dev_trg_mask)
        #   dev_trg_mask = dev_trg_mask.unsqueeze(1).expand(len(dev_trg_mask), trg_vocab_size)
        #   sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
        #   sys_out_batch = sys_out_batch.masked_select(dev_trg_mask).view(-1, trg_vocab_size)
        #   loss_g_nll = criterion_g(sys_out_batch, dev_trg_batch)
        #   loss_g_ce = criterion(fake_dis_label_out, Variable(torch.ones(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda())
        #   loss_d = criterion(real_dis_label_out, Variable(torch.ones(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda()) + criterion(fake_dis_label_out, Variable(torch.zeros(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda())
        #   logging.debug("G dev NLL loss at batch {0}: {1}".format(batch_i, loss_g_nll.data[0]))
        #   logging.debug("G dev CE loss at batch {0}: {1}".format(batch_i, loss_g_ce.data[0]))
        #   f2.write("G dev NLL loss at batch {0}: {1}\n".format(batch_i, loss_g_nll.data[0]))
        #   f2.write("G dev CE loss at batch {0}: {1}\n".format(batch_i, loss_g_ce.data[0]))
        #   logging.debug("D dev loss at batch {0}: {1}".format(batch_i, loss_d.data[0]))
        #   f2.write("D dev loss at batch {0}: {1}\n".format(batch_i, loss_d.data[0]))
        #   dev_loss_g_nll += loss_g_nll
        #   dev_loss_g_ce += loss_g_ce
        #   dev_loss_d += loss_d
        # dev_avg_loss_g_nll = dev_loss_g_nll / len(batched_dev_src)
        # dev_avg_loss_g_ce = dev_loss_g_ce / len(batched_dev_src)
        # dev_avg_loss_d = dev_loss_d / len(batched_dev_src)
        # logging.info("G DEV Average NLL loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_g_nll.cpu().data[0], epoch_i))
        # logging.info("G DEV Average CE loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_g_ce.cpu().data[0], epoch_i))
        # logging.info("D DEV Average loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_d.data[0], epoch_i))
        # # if (last_dev_avg_loss - dev_avg_loss).data[0] < options.estop:
        # #   logging.info("Early stopping triggered with threshold {0} (previous dev loss: {1}, current: {2})".format(epoch_i, last_dev_avg_loss.data[0], dev_avg_loss.data[0]))
        # #   break
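    # serialize the trained generator and discriminator as whole modules (via dill)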
    torch.save(nmt,
               open(
                   "nmt.nll_{0:.2f}.epoch_{1}".format(train_avg_loss_g_nll,
                                                      epoch_i), 'wb'),
               pickle_module=dill)
    torch.save(discriminator,
               open(
                   "discriminator.nll_{0:.2f}.epoch_{1}".format(
                       train_avg_loss_d, epoch_i), 'wb'),
               pickle_module=dill)
    f1.close()
    f2.close()


def train_d(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (torch.cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/discriminator'):
        os.makedirs('checkpoints/discriminator')

    checkpoints_path = 'checkpoints/discriminator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)

    # Load generator
    assert os.path.exists('checkpoints/generator/best_gmodel.pt')
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt')
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            # generator = torch.nn.DataParallel(generator).cuda()
            generator.cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()),
    #                                                     args.d_learning_rate, momentum=args.momentum, nesterov=True)

    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)
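    # ReduceLROnPlateau with patience=0 shrinks the LR by args.lr_shrink after any epoch without improvement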

    # Train until the discriminator accuracy reaches the target value
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    trg_acc = 0.82
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # prepare discriminator train/valid data up front using the pre-trained generator
    # (the training split may be re-sampled each epoch below)
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i,
                                  use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i,
                                  use_cuda)
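    # wrap the prepared examples for the dataloaders (maxlen presumably bounds sequences at args.fixed_max_len)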
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

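        # reseed each epoch so shuffling/sampling differs across epochs but stays reproducible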
        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator,
                                          epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train,
                                           maxlen=args.fixed_max_len)

        # discriminator training dataloader
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)

        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # set training mode
        discriminator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            disc_out = discriminator(sample['src_tokens'],
                                     sample['trg_tokens'])

            loss = criterion(disc_out, sample['labels'])
            _, prediction = F.softmax(disc_out, dim=1).topk(1)
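            # accuracy: fraction of the batch where the predicted class matches the label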
            acc = torch.sum(
                prediction == sample['labels'].unsqueeze(1)).float() / len(
                    sample['labels'])

            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                          format(logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg,
                                 optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()

        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)

                disc_out = discriminator(sample['src_tokens'],
                                         sample['trg_tokens'])

                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(
                    prediction == sample['labels'].unsqueeze(1)).float() / len(
                        sample['labels'])

                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                              format(logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg,
                                     optimizer.param_groups[0]['lr'], i))

            del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)
        # refresh lr so the while-loop condition reflects scheduler reductions
        lr = optimizer.param_groups[0]['lr']

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \
                       .format(logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i))

            if logging_meters['valid_loss'].avg < best_dev_loss:
                best_dev_loss = logging_meters['valid_loss'].avg
                torch.save(discriminator.state_dict(),
                           checkpoints_path + "best_dmodel.pt")

        # stop pre-training once the discriminator reaches the target accuracy (82%)
        if logging_meters['valid_acc'].avg >= trg_acc:
            return

        epoch_i += 1