예제 #1
0
    def load_dataset(self, args):
        """Load the train/valid splits and keep them on ``self.dataset``.

        Binary files under ``args.data`` are used when
        ``data.has_binary_files`` reports them; otherwise the raw-text loader
        is used. If either language is unset on ``args``, both are replaced
        by the pair reported by the loaded dataset. A short summary
        (dictionary sizes, examples per split) is printed to stdout.
        """
        split_names = ['train', 'valid']

        # Both loaders are called with the same positional arguments, so
        # pick the loader first and invoke it once.
        if data.has_binary_files(args.data, split_names):
            loader = data.load_dataset
        else:
            loader = data.load_raw_text_dataset
        corpus = loader(args.data, split_names, args.src_lang, args.trg_lang,
                        args.fixed_max_len)

        if args.src_lang is None or args.trg_lang is None:
            # Store the inferred pair on args so it is saved in checkpoints.
            args.src_lang, args.trg_lang = corpus.src, corpus.dst

        dictionary_line = '| [{}] dictionary: {} types'
        print(dictionary_line.format(corpus.src, len(corpus.src_dict)))
        print(dictionary_line.format(corpus.dst, len(corpus.dst_dict)))

        examples_line = '| {} {} {} examples'
        for name in split_names:
            print(examples_line.format(args.data, name,
                                       len(corpus.splits[name])))

        self.dataset = corpus
예제 #2
0
def _unwrap(model):
    """Return the module inside a ``DataParallel`` wrapper, or ``model`` itself."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module
    return model


def _load_pretrained_weights(model, checkpoint_path):
    """Copy every parameter ``model`` shares with a saved model into ``model``.

    The checkpoint is expected to be a whole pickled module (as written by
    ``torch.save(model, ...)``), not a bare state dict. Checkpoint keys that
    ``model`` does not have are ignored; keys missing from the checkpoint
    keep their current values in ``model``.
    """
    assert os.path.exists(checkpoint_path), checkpoint_path
    # NOTE(review): torch.load unpickles arbitrary objects -- only point this
    # at checkpoints you produced yourself.
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict() below copies the values onto the model's own device.
    saved_model = torch.load(checkpoint_path, map_location='cpu')
    # A model pickled from a multi-GPU run is the DataParallel wrapper, whose
    # keys carry a "module." prefix that would never match the bare model.
    saved_model = _unwrap(saved_model)

    own_state = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained = {
        k: v
        for k, v in saved_model.state_dict().items() if k in own_state
    }
    if not pretrained:
        logging.warning("No parameter in %s matches the model; nothing loaded",
                        checkpoint_path)
    # 2. overwrite entries in the existing state dict
    own_state.update(pretrained)
    # 3. load the new state dict
    model.load_state_dict(own_state)


def _reset_meters(*meter_groups):
    """Reset every non-None meter in each of the given meter dictionaries."""
    for group in meter_groups:
        for meter in group.values():
            if meter is not None:
                meter.reset()


def _greedy_tokens(sys_out_batch, shape):
    """Return the top-1 index along the last axis, reshaped to ``shape``."""
    out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))
    _, prediction = out_batch.topk(1)
    return torch.reshape(prediction.squeeze(1), shape)


def main(args):
    """Jointly train a warm-started generator and discriminator.

    Steps, in order:

    1. Load the 'train' and 'valid' splits from ``args.data``.
    2. Build an ``LSTMModel`` generator and a ``Discriminator`` and load
       their weights from ``checkpoints/zhenwarm/``.
    3. For every training batch, update the generator with either the
       policy-gradient loss (discriminator output used as reward) or the
       NLL loss, chosen by a coin flip. After every 5th update, train the
       discriminator on either the generator's greedy output (label 0) or
       the reference target (label 1), again chosen by a coin flip.
    4. After every 10000th update, run validation and pickle the generator
       into ``checkpoints/myzhencli5/``.

    Args:
        args: Options namespace. Fields read here: ``gpuid``, ``data``,
            ``src_lang``, ``trg_lang``, ``fixed_max_len``, ``g_optimizer``,
            ``g_learning_rate``, ``d_optimizer``, ``d_learning_rate``,
            ``momentum``, ``lr_shrink``, ``epochs``, ``seed``,
            ``max_tokens``, ``joint_batch_size``,
            ``sample_without_replacement``, ``curriculum``,
            ``distributed_rank``, ``distributed_world_size``,
            ``sentence_avg`` and ``clip_norm``. The model-size fields
            (``encoder_embed_dim`` etc.) and, when unset, the language pair
            are overwritten in place.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    meter_names = ('train_loss', 'valid_loss', 'train_acc', 'valid_acc',
                   'bsz')  # 'bsz' = sentences per batch
    g_logging_meters = OrderedDict((name, AverageMeter())
                                   for name in meter_names)
    d_logging_meters = OrderedDict((name, AverageMeter())
                                   for name in meter_names)

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    # Warm-start both networks from their pre-training checkpoints.
    _load_pretrained_weights(generator, 'checkpoints/zhenwarm/generator.pt')
    print("pre-trained Generator loaded successfully!")
    _load_pretrained_weights(discriminator, 'checkpoints/zhenwarm/discri.pt')
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    os.makedirs('checkpoints/myzhencli5', exist_ok=True)
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do).
    # The embeddings live on the wrapped module when DataParallel is in use,
    # so look them up on the unwrapped model.
    d_core = _unwrap(discriminator)
    for embedding in (d_core.embed_src_tokens, d_core.embed_trg_tokens):
        for p in embedding.parameters():
            p.requires_grad = False

    # define optimizer; the class is looked up by name on torch.optim rather
    # than eval()-ing a string built from command-line input.
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        [p for p in generator.parameters() if p.requires_grad],
        args.g_learning_rate)

    # momentum/nesterov are passed unconditionally, so args.d_optimizer must
    # name a class that accepts them (torch.optim.SGD does).
    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        [p for p in discriminator.parameters() if p.requires_grad],
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # Source and target share one length limit in joint training.
    max_positions = (args.fixed_max_len, args.fixed_max_len)

    # start joint training
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # NOTE(review): only torch's RNG is seeded here; the random.random()
        # coin flips below are not covered by this seed.
        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        _reset_meters(g_logging_meters, d_logging_meters)

        for i, sample in enumerate(trainloader):

            # set training mode (validation below switches to eval mode)
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                # NOTE(review): this passes the `cuda` module (always truthy)
                # as the flag; `use_cuda` was probably intended -- confirm
                # against utils.make_variable.
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # coin flip: policy-gradient step or MLE step
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                # original author's example shape: 64 X 50 X 6632
                sys_out_batch = generator(sample)

                # greedy tokens, shaped like the source batch (e.g. 64 X 50)
                prediction = _greedy_tokens(
                    sys_out_batch, sample['net_input']['src_tokens'].shape)

                # the reward must not propagate gradients into D
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # e.g. 64 X 1

                train_trg_batch = sample['target']  # e.g. 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # e.g. (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # e.g. 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # The criterion sums over tokens, so rescale the gradients by
                # the sample size. Parameters that took no part in this
                # forward pass have no gradient and are skipped.
                for p in generator.parameters():
                    if p.requires_grad and p.grad is not None:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                src_sentence = sample['net_input'][
                    'src_tokens']  # e.g. 64 X 50

                # Re-run the just-updated generator without tracking
                # gradients to obtain the machine translation.
                with torch.no_grad():
                    sys_out_batch = generator(sample)

                fake_sentence = _greedy_tokens(sys_out_batch,
                                               src_sentence.shape)
                true_sentence = torch.reshape(sample['target'],
                                              src_sentence.shape)

                # Each discriminator step sees either the machine output
                # (label 0) or the human reference (label 1), not both.
                num_sentences = sample['target'].size(0)
                if random.random() > 0.5:
                    trg_sentence = fake_sentence
                    labels = torch.zeros(num_sentences)
                else:
                    trg_sentence = true_sentence
                    labels = torch.ones(num_sentences)
                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence)
                d_loss = d_criterion(disc_out.squeeze(1), labels)
                acc = torch.sum(
                    torch.round(disc_out).squeeze(1) ==
                    labels).float() / len(labels)

                # Hand the meters plain floats: d_loss still references its
                # autograd graph, which a meter would otherwise keep alive.
                d_logging_meters['train_acc'].update(acc.item())
                d_logging_meters['train_loss'].update(d_loss.item())
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters (this also clears the running training meters)
                _reset_meters(g_logging_meters, d_logging_meters)

                with torch.no_grad():
                    # Separate loop names keep the training loop's `i` and
                    # `sample` intact.
                    for val_i, val_sample in enumerate(valloader):
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            val_sample = utils.make_variable(val_sample,
                                                             cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(val_sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))
                        dev_trg_batch = val_sample['target'].view(-1)

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = val_sample['target'].size(
                            0) if args.sentence_avg else val_sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {val_i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation: score both the human and
                        # the machine translation of every sentence. The
                        # generator output from above is reused rather than
                        # recomputed (the generator is in eval mode here).
                        src_sentence = val_sample['net_input']['src_tokens']
                        fake_sentence = _greedy_tokens(sys_out_batch,
                                                       src_sentence.shape)
                        true_sentence = torch.reshape(val_sample['target'],
                                                      src_sentence.shape)

                        num_sentences = val_sample['target'].size(0)
                        true_labels = torch.ones(num_sentences)
                        fake_labels = torch.zeros(num_sentences)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {val_i}"
                        )

                # Only the generator is checkpointed; the discriminator is
                # not saved by this script.
                checkpoint_file = (
                    checkpoints_path +
                    f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt"
                )
                # The context manager closes the file even if pickling fails.
                with open(checkpoint_file, 'wb') as handle:
                    torch.save(generator, handle, pickle_module=dill)
예제 #3
0
def main(args):
    """Pre-train the LSTM generator with a summed NLL (maximum-likelihood) loss.

    The 'train' and 'valid' splits are loaded from ``args.data`` and an
    ``LSTMModel`` is built and trained epoch by epoch. After every 5000th
    update the model is evaluated on the validation split, the
    ``ReduceLROnPlateau`` scheduler is stepped with the validation loss, and
    the whole generator is pickled into ``checkpoints/generator/``. Training
    stops once the learning rate is no longer above ``args.min_g_lr`` or
    ``args.max_epoch`` epochs have run (both are checked between epochs).

    Args:
        args: Options namespace. Fields read here: ``gpuid``, ``data``,
            ``src_lang``, ``trg_lang``, ``fixed_max_len``, ``optimizer``,
            ``learning_rate``, ``lr_shrink``, ``max_epoch``, ``min_g_lr``,
            ``seed``, ``max_source_positions``, ``max_target_positions``,
            ``max_tokens``, ``max_sentences``,
            ``sample_without_replacement``, ``curriculum``,
            ``distributed_rank``, ``distributed_world_size``,
            ``sentence_avg``, ``clip_norm`` and
            ``skip_invalid_size_inputs_valid_test``. The model-size fields
            (``encoder_embed_dim`` etc.) and, when unset, the language pair
            are overwritten in place.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # check checkpoints saving path
    os.makedirs('checkpoints/generator', exist_ok=True)
    checkpoints_path = 'checkpoints/generator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch
    logging_meters['update_times'] = AverageMeter()

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Build model
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)

    if use_cuda:
        if len(args.gpuid) > 1:
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
    else:
        generator.cpu()

    # Under DataParallel the encoder/decoder are attributes of the wrapped
    # module, not of the wrapper, so query them on the unwrapped model.
    if isinstance(generator, torch.nn.DataParallel):
        g_core = generator.module
    else:
        g_core = generator

    print("Training generator...")

    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')

    # The optimizer class is looked up by name on torch.optim rather than
    # eval()-ing a string built from command-line input.
    optimizer = getattr(torch.optim, args.optimizer)(generator.parameters(),
                                                     args.learning_rate)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Length limits, computed once up front rather than on every epoch.
    max_positions_train = (min(args.max_source_positions,
                               g_core.encoder.max_positions()),
                           min(args.max_target_positions,
                               g_core.decoder.max_positions()))
    # validation -- this is a crude estimation because there might be some
    # padding at the end
    max_positions_valid = (
        g_core.encoder.max_positions(),
        g_core.decoder.max_positions(),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf

    epoch_i = 1
    lr = optimizer.param_groups[0]['lr']
    num_update = 0
    # main training loop
    while lr > args.min_g_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        # Initialize dataloader, starting at batch_offset
        train_itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for meter in logging_meters.values():
            if meter is not None:
                meter.reset()

        for i, sample in enumerate(train_itr):
            # (re-)enter training mode; validation below switches to eval
            generator.train()
            if use_cuda:
                # wrap input tensors in cuda tensors
                # NOTE(review): this passes the `cuda` module (always truthy)
                # as the flag; `use_cuda` was probably intended -- confirm
                # against utils.make_variable.
                sample = utils.make_variable(sample, cuda=cuda)

            sys_out_batch = generator(sample)
            out_batch = sys_out_batch.contiguous().view(
                -1, sys_out_batch.size(-1))
            train_trg_batch = sample['target'].view(-1)
            loss = g_criterion(out_batch, train_trg_batch)
            sample_size = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.item() / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            logging.debug(
                "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))
            optimizer.zero_grad()
            loss.backward()

            # The criterion sums over tokens, so rescale the gradients by the
            # sample size. Parameters that took no part in this forward pass
            # have no gradient and are skipped.
            for p in generator.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            num_update = num_update + 1
            if num_update % 5000 == 0:
                # Initialize dataloader
                valid_itr = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.max_sentences,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=args.
                    skip_invalid_size_inputs_valid_test,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )
                # set validation mode
                generator.eval()

                # reset meters (this also clears the running training meters)
                for meter in logging_meters.values():
                    if meter is not None:
                        meter.reset()
                with torch.no_grad():
                    # Separate loop names keep the training loop's `i` and
                    # `sample` intact.
                    for val_i, val_sample in enumerate(valid_itr):
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            val_sample = utils.make_variable(val_sample,
                                                             cuda=cuda)
                        sys_out_batch = generator(val_sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))

                        val_trg_batch = val_sample['target'].view(-1)
                        val_loss = g_criterion(out_batch, val_trg_batch)

                        val_size = val_sample['target'].size(
                            0) if args.sentence_avg else val_sample['ntokens']
                        val_loss = val_loss.item() / val_size / math.log(2)
                        logging_meters['valid_loss'].update(val_loss, val_size)
                        logging.debug(
                            "g dev loss at batch {0}: {1:.3f}".format(
                                val_i, logging_meters['valid_loss'].avg))

                # update learning rate
                lr_scheduler.step(logging_meters['valid_loss'].avg)
                lr = optimizer.param_groups[0]['lr']

                logging.info(
                    "Average g loss value per instance is {0} at the end of epoch {1}"
                    .format(logging_meters['valid_loss'].avg, epoch_i))

                # The original template had its two indices swapped, writing
                # the loss after "numupdate" and the update count after
                # "nll_". Each value now sits next to its own label.
                checkpoint_file = (
                    checkpoints_path +
                    "numupdate{0}.data.nll_{1:.3f}.pt".format(
                        num_update, logging_meters['valid_loss'].avg))
                # The context manager closes the file even if pickling fails.
                with open(checkpoint_file, 'wb') as handle:
                    torch.save(generator, handle)

        epoch_i += 1
예제 #4
0
def main(args):
    """Run joint (adversarial) training of the generator and discriminator.

    Steps performed:

    1. Load the ``train`` / ``valid`` splits (binary files when present,
       raw text otherwise).
    2. Build an ``LSTMModel`` generator and a ``Discriminator`` and
       initialise each from its pre-training checkpoint, calling
       ``train_g`` / ``train_d`` first when the checkpoint file is missing.
    3. For ``args.epochs`` epochs, update the generator on every batch
       (policy gradient or MLE, chosen by a coin flip) and then update the
       discriminator on a random half of the real + generated sentences.
    4. After each epoch, evaluate both models on the validation split and
       pickle the generator into ``checkpoints/joint/``; the generator with
       the lowest validation loss is also written to ``best_gmodel.pt``
       in that directory.

    Args:
        args: Parsed command-line namespace.  ``src_lang`` / ``trg_lang``
            are filled in from the dataset when unset, and the
            encoder/decoder size attributes are overwritten with the
            hard-coded values below.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    def _reset_meters():
        """Reset every generator and discriminator meter."""
        for meters in (g_logging_meters, d_logging_meters):
            for meter in meters.values():
                if meter is not None:
                    meter.reset()

    def _load_matching_weights(model, path):
        """Copy into `model` the checkpoint entries whose keys it also has."""
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path)
        # keep only keys the freshly built model knows, then overwrite them
        model_dict.update(
            {k: v for k, v in pretrained_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model (pre-train it first if no checkpoint)
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    _load_matching_weights(generator, g_model_path)

    print("Generator has successfully loaded!")

    # try to load discriminator model (pre-train it first if no checkpoint)
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    _load_matching_weights(discriminator, d_model_path)

    print("Discriminator has successfully loaded!")

    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # DataParallel exposes the wrapped model only through `.module`, so keep
    # direct handles for the attribute accesses below (`decoder.is_testing`,
    # `embed_*_tokens`), which would otherwise fail on the wrapper.
    g_core = (generator.module
              if isinstance(generator, torch.nn.DataParallel) else generator)
    d_core = (discriminator.module if isinstance(
        discriminator, torch.nn.DataParallel) else discriminator)

    # adversarial training checkpoints saving path
    os.makedirs('checkpoints/joint', exist_ok=True)
    checkpoints_path = 'checkpoints/joint/'

    # define loss function
    # (`reduction='sum'` is the current spelling of size_average=False,
    # reduce=True; g_criterion is not used further down in this function.)
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in d_core.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in d_core.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer: look the class up by name instead of eval()-ing a
    # string assembled from command-line input
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)

    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        _reset_meters()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                g_core.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                g_core.decoder.is_testing = False
                with torch.no_grad():
                    # the discriminator's output on the generated sentence
                    # is used as the reward
                    reward = discriminator(
                        sample['net_input']['src_tokens'], prediction)
                train_trg_batch = sample['target']
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    # parameters that took no part in this forward pass
                    # have no gradient to rescale
                    if p.requires_grad and p.grad is not None:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(torch.ones(bsz).float())

            with torch.no_grad():
                g_core.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                g_core.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(torch.zeros(bsz).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            # shuffle the 2*bsz candidates and keep the first bsz of them
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)
            # BCEWithLogitsLoss needs floating-point targets, so the labels
            # are passed as-is (same as in the validation loop below)
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(
                torch.sigmoid(disc_out).round() == labels).float() / len(
                    labels)
            # store plain Python numbers so the meters do not hold on to the
            # autograd graph of every batch
            d_logging_meters['train_acc'].update(acc.item())
            d_logging_meters['train_loss'].update(d_loss.item())
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        _reset_meters()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # evaluate with half human-translation and half machine
                # translation

                true_sentence = sample['target']
                true_labels = Variable(torch.ones(bsz).float())

                g_core.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                g_core.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(torch.zeros(bsz).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                # NOTE(review): the training-time calls pass only
                # (src, trg); confirm the Discriminator still accepts this
                # third padding-index argument.
                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(
                    torch.sigmoid(disc_out).round() == labels).float() / len(
                        labels)
                d_logging_meters['valid_acc'].update(acc.item())
                d_logging_meters['valid_loss'].update(d_loss.item())

        # per-epoch snapshot; `with` closes the file handle after saving
        epoch_ckpt = checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
            g_logging_meters['valid_loss'].avg, epoch_i)
        with open(epoch_ckpt, 'wb') as ckpt_file:
            torch.save(generator, ckpt_file, pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            with open(checkpoints_path + "best_gmodel.pt", 'wb') as ckpt_file:
                torch.save(generator, ckpt_file, pickle_module=dill)
예제 #5
0
def main(args):
    """Jointly train a freshly initialised generator and discriminator.

    Steps performed:

    1. Load the ``train`` / ``valid`` splits (binary files when present,
       raw text otherwise).
    2. Build an ``LSTMModel`` generator and a ``Discriminator`` from
       scratch (no checkpoint is loaded here) and move them to GPU/CPU.
    3. For ``args.epochs`` epochs, on every batch: run a policy-gradient
       generator step (``generator('PG', ...)``), then an MLE generator
       step (``generator('MLE', ...)``), then a discriminator step on the
       third value returned by each of those two generator calls.
    4. After each epoch, compute the generator's validation loss and pickle
       the generator into ``checkpoints/professorjp/``; the generator with
       the lowest validation loss is also written to ``best_gmodel.pt``
       in that directory.

    Args:
        args: Parsed command-line namespace.  ``src_lang`` / ``trg_lang``
            are filled in from the dataset when unset, and the
            encoder/decoder size attributes are overwritten with the
            hard-coded values below.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    def _reset_meters():
        """Reset every generator and discriminator meter."""
        for meters in (g_logging_meters, d_logging_meters):
            for meter in meters.values():
                if meter is not None:
                    meter.reset()

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args.decoder_embed_dim,
                                  args.discriminator_hidden_size,
                                  args.discriminator_linear_size,
                                  args.discriminator_lin_dropout,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        """Return -log(tf_scores) - log(1 - ar_scores), element-wise.

        The loss is smallest when `tf_scores` approach 1 and `ar_scores`
        approach 0; the 1e-6 offset keeps the logarithms finite.
        """
        tf_loss = torch.log(tf_scores + 1e-6) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-6) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    os.makedirs('checkpoints/professorjp', exist_ok=True)
    checkpoints_path = 'checkpoints/professorjp/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # define optimizer: look the class up by name instead of eval()-ing a
    # string assembled from command-line input
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)

    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # a different but reproducible torch seed for every epoch
        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        _reset_meters()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            sys_out_batch_PG, _, hidden_list_PG = generator(
                'PG', epoch_i, sample)

            with torch.no_grad():
                # the discriminator's output is used as the reward
                reward = discriminator(hidden_list_PG)

            train_trg_batch_PG = sample['target']

            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG,
                                      reward, use_cuda)
            sample_size_PG = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(),
                                                  sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
            )
            g_optimizer.zero_grad()
            # the graph is kept because hidden_list_PG is fed to the
            # discriminator again in part II, which backpropagates through it
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()

            # MLE training step on the same batch
            sys_out_batch_MLE, _, hidden_list_MLE = generator(
                "MLE", epoch_i, sample)

            # flatten to (num_positions, vocab) for NLLLoss
            out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                -1, sys_out_batch_MLE.size(-1))

            train_trg_batch_MLE = sample['target'].view(-1)
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

            sample_size_MLE = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE,
                                                  sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
            )
            # NOTE(review): gradients are not zeroed here, so the MLE
            # gradients are added on top of the policy-gradient ones from
            # the step above before being rescaled -- confirm this is
            # intended.
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                # parameters that took no part in the forward passes have
                # no gradient to rescale
                if p.requires_grad and p.grad is not None:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()

            num_update += 1

            # part II: train the discriminator

            d_MLE = discriminator(hidden_list_MLE)
            d_PG = discriminator(hidden_list_PG)
            d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
            logging.debug(f"D training loss {d_loss} at batch {i}")

            d_optimizer.zero_grad()
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        _reset_meters()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation (the discriminator is not evaluated)
                sys_out_batch_test, _, _ = generator('test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(
                    -1, sys_out_batch_test.size(-1))
                dev_trg_batch = sample['target'].view(-1)

                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test,
                                                      sample_size_test)
                logging.debug(
                    f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                )

        # per-epoch snapshot; `with` closes the file handle after saving
        epoch_ckpt = (
            checkpoints_path +
            f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt"
        )
        with open(epoch_ckpt, 'wb') as ckpt_file:
            torch.save(generator, ckpt_file, pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            with open(checkpoints_path + "best_gmodel.pt", 'wb') as ckpt_file:
                torch.save(generator, ckpt_file, pickle_module=dill)
def main(args):
    """Pre-train the LSTM generator and checkpoint it after every epoch.

    Loads the 'train' and 'valid' splits from ``args.data``, builds an
    ``LSTMModel`` and alternates one training pass with one validation pass
    while the learning rate stays above ``args.min_lr`` and at most
    ``args.max_epoch`` epochs have run (no epoch limit when
    ``args.max_epoch`` is falsy).  The learning rate is shrunk by
    ``args.lr_shrink`` whenever the validation loss stops improving.

    Args:
        args: Parsed command-line namespace.  The model-size attributes
            (``encoder_embed_dim`` and friends) are overwritten here, and
            ``src_lang`` / ``trg_lang`` are filled in from the dataset when
            either of them is None.

    Side effects:
        * Appends the running average training loss, one line per batch, to
          the file ``train_loss`` in the current working directory.
        * Writes ``<args.model_file>data.nll_<loss>.epoch_<n>.pt`` after
          every epoch and ``<args.model_file>best_gmodel.pt`` whenever the
          validation loss improves (both hold the generator ``state_dict``).
    """
    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Set model parameters (these override whatever was passed in args)
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Build model
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)

    if use_cuda:
        generator.cuda()
    else:
        generator.cpu()

    # Look the optimizer class up by name instead of eval()-ing a string
    # built from a command-line argument.
    optimizer = getattr(torch.optim, args.optimizer)(generator.parameters(),
                                                     args.learning_rate)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf

    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # The context manager guarantees the loss log is flushed and closed even
    # when training aborts with an exception.
    with open("train_loss", "a") as f1:
        # main training loop
        while lr > args.min_lr and epoch_i <= max_epoch:
            logging.info("At {0}-th epoch.".format(epoch_i))

            seed = args.seed + epoch_i
            torch.manual_seed(seed)

            max_positions_train = (min(args.max_source_positions,
                                       generator.encoder.max_positions()),
                                   min(args.max_target_positions,
                                       generator.decoder.max_positions()))

            # Initialize dataloader, starting at batch_offset
            itr = dataset.train_dataloader(
                'train',
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences,
                max_positions=max_positions_train,
                seed=seed,
                epoch=epoch_i,
                sample_without_replacement=args.sample_without_replacement,
                sort_by_source_size=(epoch_i <= args.curriculum),
                shard_id=args.distributed_rank,
                num_shards=args.distributed_world_size,
            )
            # set training mode
            generator.train()

            # reset meters
            for meter in logging_meters.values():
                if meter is not None:
                    meter.reset()

            for i, sample in enumerate(itr):

                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                loss = generator(sample)
                nsentences = sample['target'].size(0)
                sample_size = (nsentences
                               if args.sentence_avg else sample['ntokens'])
                # report the loss in bits per sample unit
                logging_loss = loss.data / sample_size / math.log(2)
                logging_meters['bsz'].update(nsentences)
                logging_meters['train_loss'].update(logging_loss, sample_size)
                f1.write("{0}\n".format(logging_meters['train_loss'].avg))
                logging.debug(
                    "loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".
                    format(i, logging_meters['train_loss'].avg,
                           round(logging_meters['bsz'].avg),
                           optimizer.param_groups[0]['lr']))
                optimizer.zero_grad()
                loss.backward()

                # Rescale the summed gradients by the sample size.  A
                # parameter that did not take part in the forward pass has
                # no gradient, so skip it instead of crashing on None.
                for p in generator.parameters():
                    if p.requires_grad and p.grad is not None:
                        p.grad.data.div_(sample_size)

                # clip_grad_norm_ is the in-place successor of the
                # deprecated clip_grad_norm.
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                optimizer.step()

            # validation -- this is a crude estimation because there might
            # be some padding at the end.  Computed once per epoch, outside
            # the batch loop, so it exists even if the training iterator
            # yielded no batch.
            max_positions_valid = (
                generator.encoder.max_positions(),
                generator.decoder.max_positions(),
            )

            # Initialize dataloader
            itr = dataset.eval_dataloader(
                'valid',
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences,
                max_positions=max_positions_valid,
                skip_invalid_size_inputs_valid_test=args.
                skip_invalid_size_inputs_valid_test,
                descending=True,  # largest batch first to warm the caching allocator
                shard_id=args.distributed_rank,
                num_shards=args.distributed_world_size,
            )
            # set validation mode
            generator.eval()

            # reset meters (note: this clears the training meters as well)
            for meter in logging_meters.values():
                if meter is not None:
                    meter.reset()

            for i, sample in enumerate(itr):
                with torch.no_grad():
                    if use_cuda:
                        # wrap input tensors in cuda tensors
                        sample = utils.make_variable(sample, cuda=cuda)
                    loss = generator(sample)
                    sample_size = sample['target'].size(
                        0) if args.sentence_avg else sample['ntokens']
                    loss = loss / sample_size / math.log(2)
                    logging_meters['valid_loss'].update(loss, sample_size)
                    logging.debug("dev loss at batch {0}: {1:.3f}".format(
                        i, logging_meters['valid_loss'].avg))

            valid_loss = logging_meters['valid_loss'].avg

            # update learning rate
            lr_scheduler.step(valid_loss)
            lr = optimizer.param_groups[0]['lr']

            logging.info(
                "Average loss value per instance is {0} at the end of epoch {1}"
                .format(valid_loss, epoch_i))

            # Checkpoint files are opened in a with-block so the handles are
            # closed as soon as the state dict has been written.
            epoch_ckpt_path = (
                args.model_file +
                "data.nll_{0:.3f}.epoch_{1}.pt".format(valid_loss, epoch_i))
            with open(epoch_ckpt_path, 'wb') as ckpt_file:
                torch.save(generator.state_dict(), ckpt_file)

            if valid_loss < best_dev_loss:
                best_dev_loss = valid_loss
                with open(args.model_file + "best_gmodel.pt",
                          'wb') as ckpt_file:
                    torch.save(generator.state_dict(), ckpt_file)

            epoch_i += 1
def main(args):
    """Jointly train the generator with two discriminators.

    Workflow, as implemented below:

    1. Load the 'train' / 'valid' splits from ``args.data``.
    2. Build the generator (``LSTMModel``), a hidden-state discriminator
       (``Discriminator_h``) and a sentence discriminator
       (``Discriminator_s``); warm-start the generator and the sentence
       discriminator from ``checkpoints/zhenwarm/genev.pt`` and
       ``checkpoints/zhenwarm/discri.pt``.
    3. For every training batch pick, at random, either an MLE update or a
       policy-gradient (PG) update of the generator; the first batch of each
       epoch is always a PG step.  The PG step is only applied when the
       latest PG loss is below 0.5.
    4. Every 10 updates train ``discriminator_h`` (only while the latest PG
       loss is below 2) and ``discriminator_s`` (on generator output
       labelled as fake).
    5. Every 10000 updates run validation and pickle the whole generator
       (via ``dill``) under ``checkpoints/realmyzhenup10shrink07drop01new/``.

    Args:
        args: Parsed command-line namespace.  The model-size attributes are
            overwritten here, and ``src_lang`` / ``trg_lang`` are filled in
            from the dataset when either of them is None.

    Raises:
        AssertionError: if one of the warm-start checkpoints is missing.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['MLE_train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['PG_train_loss'] = AverageMeter()
    g_logging_meters['MLE_train_acc'] = AverageMeter()
    g_logging_meters['PG_train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['D_h_train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['D_s_train_loss'] = AverageMeter()
    d_logging_meters['D_h_train_acc'] = AverageMeter()
    d_logging_meters['D_s_train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters (these override whatever was passed in args)
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.1
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.1
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim,
                                      args.discriminator_hidden_size,
                                      args.discriminator_linear_size,
                                      args.discriminator_lin_dropout,
                                      use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args,
                                      dataset.src_dict,
                                      dataset.dst_dict,
                                      use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    # Load generator model.
    # NOTE(review): torch.load unpickles a whole module object; only point
    # it at checkpoints you trust.
    g_model_path = 'checkpoints/zhenwarm/genev.pt'
    assert os.path.exists(g_model_path)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    # 1. keep only the pretrained entries the fresh generator knows about
    pretrained_dict = {
        k: v
        for k, v in model.state_dict().items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")

    # Load discriminator model (same three steps as for the generator)
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    d_model_dict = discriminator_s.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = {
        k: v
        for k, v in d_model.state_dict().items() if k in d_model_dict
    }
    d_model_dict.update(d_pretrained_dict)
    discriminator_s.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    # fix discriminator_s word embedding (as Wu et al. do).  This has to
    # happen BEFORE the optional DataParallel wrapping below: the wrapper
    # does not forward attribute access, so `embed_src_tokens` would no
    # longer be reachable on the wrapped object.
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path; exist_ok avoids the
    # check-then-create race of a separate os.path.exists test
    os.makedirs('checkpoints/realmyzhenup10shrink07drop01new', exist_ok=True)
    checkpoints_path = 'checkpoints/realmyzhenup10shrink07drop01new/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # define optimizer; classes are looked up by name rather than through
    # eval() on a string built from command-line arguments
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        [p for p in generator.parameters() if p.requires_grad],
        args.g_learning_rate)

    d_optimizer_h = getattr(torch.optim, args.d_optimizer)(
        [p for p in discriminator_h.parameters() if p.requires_grad],
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    d_optimizer_s = getattr(torch.optim, args.d_optimizer)(
        [p for p in discriminator_s.parameters() if p.requires_grad],
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for meter in g_logging_meters.values():
            if meter is not None:
                meter.reset()
        for meter in d_logging_meters.values():
            if meter is not None:
                meter.reset()

        for i, sample in enumerate(trainloader):
            # Training mode is (re)set on every batch because the periodic
            # validation further down switches the models to eval mode.
            generator.train()
            discriminator_h.train()
            discriminator_s.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            # Coin flip between MLE and policy gradient; batch 0 of every
            # epoch is always a policy-gradient step.
            train_MLE = random.random() > 0.5 and i != 0

            if train_MLE:
                sys_out_batch_MLE, _, hidden_list_MLE = generator(
                    "MLE", epoch_i, sample)

                # flatten to (batch * length) x vocabulary
                out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                    -1, sys_out_batch_MLE.size(-1))

                train_trg_batch_MLE = sample['target'].view(-1)
                loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

                nsentences = sample['target'].size(0)
                sample_size_MLE = (nsentences if args.sentence_avg else
                                   sample['ntokens'])
                logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(
                    2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['MLE_train_loss'].update(
                    logging_loss_MLE, sample_size_MLE)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['MLE_train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss_MLE.backward()
                # Rescale the summed gradients by the sample size.  A
                # parameter without a gradient (unused in this forward
                # pass) is skipped instead of crashing on None.
                for p in generator.parameters():
                    if p.requires_grad and p.grad is not None:
                        p.grad.data.div_(sample_size_MLE)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            else:
                ## part I: use gradient policy method to train the generator
                sys_out_batch_PG, _, hidden_list_PG = generator(
                    'PG', epoch_i, sample)

                # flatten to (batch * length) x vocabulary
                out_batch_PG = sys_out_batch_PG.contiguous().view(
                    -1, sys_out_batch_PG.size(-1))

                # greedy (top-1) token per position, reshaped like the
                # source batch
                _, prediction = out_batch_PG.topk(1)
                prediction = prediction.squeeze(1)
                prediction = torch.reshape(
                    prediction, sample['net_input']['src_tokens'].shape)
                # The sentence discriminator only scores the sample here;
                # it must not receive gradients from the generator step.
                with torch.no_grad():
                    reward = discriminator_s(sample['net_input']['src_tokens'],
                                             prediction)
                train_trg_batch_PG = sample['target']

                pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG,
                                          reward, use_cuda)
                sample_size_PG = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss_PG = pg_loss_PG / math.log(2)
                g_logging_meters['PG_train_loss'].update(
                    logging_loss_PG.item(), sample_size_PG)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                # Apply the PG step only while the loss just recorded is
                # small (presumably `.val` is the most recent value --
                # confirm against AverageMeter).
                if g_logging_meters['PG_train_loss'].val < 0.5:
                    pg_loss_PG.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                                   args.clip_norm)
                    g_optimizer.step()

            num_update += 1

            #  part II: train the discriminators every 10 updates
            if num_update % 10 == 0:
                # discriminator_h
                if g_logging_meters["PG_train_loss"].val < 2:
                    if train_MLE:
                        # hidden states from the MLE pass: push score to 1
                        d_MLE = discriminator_h(hidden_list_MLE.detach())
                        h_d_loss = (torch.log(d_MLE + 1e-9) * (-1)).sum()
                    else:
                        # hidden states from the PG pass: push score to 0
                        d_PG = discriminator_h(hidden_list_PG.detach())
                        h_d_loss = (torch.log(1 - d_PG + 1e-9) * (-1)).sum()

                    logging.debug(f"D_h training loss {h_d_loss} at batch {i}")

                    d_optimizer_h.zero_grad()

                    h_d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        discriminator_h.parameters(), args.clip_norm)
                    d_optimizer_h.step()

                # discriminator_s
                src_sentence = sample['net_input']['src_tokens']

                # train with machine translation output i.e generator output
                with torch.no_grad():
                    sys_out_batch, _, _ = generator('PG', epoch_i, sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)

                # one "fake" (0) label per sentence in the batch
                fake_labels = torch.zeros(sample['target'].size(0)).float()
                fake_sentence = torch.reshape(prediction, src_sentence.shape)

                if use_cuda:
                    fake_labels = fake_labels.cuda()

                fake_disc_out = discriminator_s(src_sentence, fake_sentence)
                s_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                acc = torch.sum(
                    torch.round(fake_disc_out).squeeze(1) ==
                    fake_labels).float() / len(fake_labels)

                d_logging_meters['D_s_train_acc'].update(acc)
                # Store a plain float: accumulating the loss tensor itself
                # would keep every batch's autograd graph alive until the
                # meters are next reset.
                d_logging_meters['D_s_train_loss'].update(s_d_loss.item())
                logging.debug(
                    f"D_s training loss {d_logging_meters['D_s_train_loss'].avg:.3f}, acc {d_logging_meters['D_s_train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer_s.zero_grad()
                s_d_loss.backward()
                d_optimizer_s.step()

            if num_update % 10000 == 0:
                # validation
                # set validation mode
                print(
                    'validation and save+++++++++++++++++++++++++++++++++++++++++++++++'
                )
                generator.eval()
                discriminator_h.eval()
                discriminator_s.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters (note: this clears the training meters of
                # the running epoch as well)
                for meter in g_logging_meters.values():
                    if meter is not None:
                        meter.reset()
                for meter in d_logging_meters.values():
                    if meter is not None:
                        meter.reset()

                # Dedicated loop names so the outer training loop's `i` and
                # `sample` are not shadowed.
                for val_i, val_sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            val_sample = utils.make_variable(val_sample,
                                                             cuda=cuda)

                        # generator validation
                        sys_out_batch_test, _, _ = generator(
                            'test', epoch_i, val_sample)
                        out_batch_test = sys_out_batch_test.contiguous().view(
                            -1, sys_out_batch_test.size(-1))
                        dev_trg_batch = val_sample['target'].view(-1)

                        loss_test = g_criterion(out_batch_test, dev_trg_batch)
                        sample_size_test = val_sample['target'].size(
                            0) if args.sentence_avg else val_sample['ntokens']
                        loss_test = loss_test / sample_size_test / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss_test, sample_size_test)
                        logging.debug(
                            f"G dev loss at batch {val_i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                # The with-block closes the checkpoint file as soon as the
                # generator has been written.
                ckpt_path = (
                    checkpoints_path +
                    f"numupdate_{num_update/10000}w.sampling_{g_logging_meters['valid_loss'].avg:.3f}.pt"
                )
                with open(ckpt_path, 'wb') as ckpt_file:
                    torch.save(generator, ckpt_file, pickle_module=dill)
예제 #8
0
def main(args):
    """Translate the 'test' split with a trained generator and dump the text.

    Builds the model class selected by ``args.model_name``, restores its
    weights from ``args.model_file`` (or a built-in default path), runs beam
    search over the test split and writes one hypothesis per line to
    ``predictions.txt`` with the matching reference on the same line number
    of ``real.txt``.

    Args:
        args: Parsed command-line namespace.  It is updated in place: the
            model-size attributes are set here, then every key found in the
            ``params.json`` next to the checkpoint is copied onto it, and
            ``src_lang`` / ``trg_lang`` are filled in from the dataset when
            either of them is None.

    Raises:
        AssertionError: if ``args.model_name`` is None or the checkpoint
            file does not exist.
        Exception: if ``args.model_name`` is not one of gan, vae or mle.
    """
    model_name = args.model_name
    assert model_name is not None
    # "gan" and "mle" share the same architecture; only "vae" differs.
    if model_name in ("gan", "mle"):
        Model = LSTMModel
    elif model_name == "vae":
        Model = VarLSTMModel
    else:
        raise Exception("Model name should be: gan|vae|mle")

    if len(args.gpuid) >= 1 and args.gpuid[0] >= 0:
        use_cuda = True
        cuda.set_device(args.gpuid[0])
        map_to = torch.device(f"cuda:{args.gpuid[0]}")
    else:
        use_cuda = False
        map_to = torch.device('cpu')

    # Load dataset
    if data.has_binary_files(args.data, ['test']):
        dataset = data.load_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )

    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, 'test',
                                       len(dataset.splits['test'])))

    # Set model parameters (defaults; params.json below may override them)
    args.encoder_embed_dim = 128
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 128
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 128
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Load model
    if args.model_file is None:
        g_model_path = 'checkpoint/VAE_2021-03-04 12:16:21/best_gmodel.pt'
    else:
        g_model_path = args.model_file

    # Restore the hyper-parameters stored next to the checkpoint.  The
    # with-block closes the file handle, which the bare open().read() it
    # replaces never did.
    params_path = os.path.join(os.path.dirname(g_model_path), "params.json")
    with open(params_path, encoding="utf-8") as params_file:
        args.__dict__.update(json.load(params_file))

    assert os.path.exists(g_model_path), f"Path does not exist {g_model_path}"
    generator = Model(args,
                      dataset.src_dict,
                      dataset.dst_dict,
                      use_cuda=use_cuda)
    model_dict = generator.state_dict()
    # NOTE(review): torch.load unpickles a whole module object; only point
    # it at checkpoints you trust.
    model = torch.load(g_model_path, map_location=map_to)
    # 1. keep only the pretrained entries the fresh generator knows about
    pretrained_dict = {
        k: v
        for k, v in model.state_dict().items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    generator.eval()

    print("Generator loaded successfully!")

    if use_cuda:
        generator.cuda()
    else:
        generator.cpu()

    max_positions = generator.encoder.max_positions()

    testloader = dataset.eval_dataloader(
        'test',
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.
        skip_invalid_size_inputs_valid_test,
    )

    translator = SequenceGenerator(generator,
                                   beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen,
                                   unk_penalty=args.unkpen)

    if use_cuda:
        translator.cuda()

    # Write both files as UTF-8 explicitly so non-ASCII translations do not
    # depend on the platform's default encoding.
    with open('predictions.txt', 'w', encoding='utf-8') as translation_writer, \
            open('real.txt', 'w', encoding='utf-8') as ground_truth_writer:

        translations = translator.generate_batched_itr(
            testloader,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda)

        for _, _, target_tokens, hypos in translations:
            # Process ground truth
            target_tokens = target_tokens.int().cpu()
            target_str = dataset.dst_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

            # Process top predictions (slicing already clamps to len(hypos))
            for hypo in hypos[:args.nbest]:
                hypo_tokens = hypo['tokens'].int().cpu()
                hypo_str = dataset.dst_dict.string(hypo_tokens,
                                                   args.remove_bpe)

                # Append the newline at write time.  Mutating target_str
                # inside this loop added one more blank line per extra
                # hypothesis, which broke the line alignment between the
                # two files whenever nbest > 1.
                translation_writer.write(hypo_str + '\n')
                ground_truth_writer.write(target_str + '\n')
예제 #9
0
def main(args):
    """Train the discriminator using a generator restored from a checkpoint.

    The function loads the dataset, builds a ``Discriminator`` and an
    ``LSTMModel`` generator, restores the generator weights from a hard-coded
    checkpoint, and then trains the discriminator with cross-entropy loss.
    Every 5000 parameter updates the discriminator is evaluated on the
    validation data; the validation loss drives a ``ReduceLROnPlateau``
    scheduler and decides which checkpoints are written under
    ``checkpoints/discriminator/``.

    Training stops when the optimizer's learning rate drops to
    ``args.min_d_lr`` or below, or after ``args.max_epoch`` epochs (unbounded
    when ``args.max_epoch`` is falsy).

    Args:
        args: Parsed options. Attributes read here: ``data``, ``src_lang``,
            ``trg_lang``, ``fixed_max_len``, ``lr_shrink``, ``max_epoch``,
            ``min_d_lr``, ``seed``, ``sample_without_replacement``,
            ``joint_batch_size`` and ``clip_norm``. ``src_lang`` and
            ``trg_lang`` are overwritten with the languages inferred from the
            dataset when either is ``None``.

    Raises:
        AssertionError: If the generator checkpoint file does not exist.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    # Load dataset (binary files when available, raw text otherwise)
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    use_cuda = (torch.cuda.device_count() >= 1)

    # Make sure the checkpoint directory exists; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    checkpoints_path = 'checkpoints/discriminator/'
    os.makedirs(checkpoints_path, exist_ok=True)

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)

    # Load generator
    generator_ckpt = (
        'checkpoints/generator/numupdate2.998810066080318.data.nll_245000.0.pt'
    )
    assert os.path.exists(generator_ckpt), \
        'generator checkpoint not found: ' + generator_ckpt
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    # Load onto the CPU so that a checkpoint saved from a GPU run can also be
    # restored on a CPU-only host; load_state_dict copies the values into the
    # generator's own parameters regardless of device.
    pretrained_dict = torch.load(generator_ckpt, map_location='cpu')
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            # NOTE(review): state_dict() of a DataParallel wrapper prefixes
            # keys with "module."; checkpoints saved below inherit that.
            discriminator = torch.nn.DataParallel(discriminator).cuda()
        else:
            discriminator.cuda()
        generator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    def _loss_and_accuracy(batch):
        """Run the discriminator on one batch; return (loss, accuracy)."""
        labels = batch['labels']
        disc_out = discriminator(batch['src_tokens'], batch['trg_tokens'])
        batch_loss = criterion(disc_out, labels)
        _, prediction = F.softmax(disc_out, dim=1).topk(1)
        batch_acc = torch.sum(
            prediction == labels.unsqueeze(1)).float() / len(labels)
        return batch_loss, batch_acc

    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # Prepare the train/valid data once up front; the training data is only
    # re-prepared per epoch when sample_without_replacement is enabled.
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i,
                                  use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i,
                                  use_cuda)
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    num_update = 0
    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        # set training mode
        discriminator.train()
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator,
                                          epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train,
                                           maxlen=args.fixed_max_len)

        # discriminator training dataloader
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)

        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # reset meters
        for meter in logging_meters.values():
            if meter is not None:
                meter.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            loss, acc = _loss_and_accuracy(sample)
            acc_value = acc.item()

            logging_meters['train_acc'].update(acc_value)
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                          format(logging_meters['train_loss'].avg, acc_value, logging_meters['train_acc'].avg,
                                 optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm_ is the in-place replacement for the deprecated
            # clip_grad_norm.
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            # release the graph-holding tensors before the next batch
            del loss, acc

            num_update += 1
            if num_update % 5000 != 0:
                continue

            # ---- periodic validation ----
            discriminator.eval()

            # Start from clean validation meters so the averages describe
            # only this pass and not earlier passes of the same epoch.
            logging_meters['valid_loss'].reset()
            logging_meters['valid_acc'].reset()

            with torch.no_grad():
                # Distinct loop names keep the training loop's `i` and
                # `sample` from being shadowed.
                for valid_i, valid_sample in enumerate(valid_loader):
                    if use_cuda:
                        # wrap input tensors in cuda tensors
                        valid_sample = utils.make_variable(valid_sample,
                                                           cuda=use_cuda)

                    valid_loss, valid_acc = _loss_and_accuracy(valid_sample)
                    valid_acc_value = valid_acc.item()

                    logging_meters['valid_acc'].update(valid_acc_value)
                    logging_meters['valid_loss'].update(valid_loss.item())
                    logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                                  format(logging_meters['valid_loss'].avg, valid_acc_value, logging_meters['valid_acc'].avg,
                                         optimizer.param_groups[0]['lr'], valid_i))

                    del valid_loss, valid_acc

            avg_valid_loss = logging_meters['valid_loss'].avg
            avg_valid_acc = logging_meters['valid_acc'].avg

            lr_scheduler.step(avg_valid_loss)
            # The scheduler may have lowered the rate; refresh the value the
            # outer loop condition checks so training can actually stop.
            lr = optimizer.param_groups[0]['lr']

            if avg_valid_acc >= 0.70:
                torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \
                           .format(avg_valid_loss, avg_valid_acc, epoch_i))

                if avg_valid_loss < best_dev_loss:
                    best_dev_loss = avg_valid_loss
                    torch.save(discriminator.state_dict(),
                               checkpoints_path + "best_dmodel.pt")

            # Resume training mode for the remaining batches of this epoch.
            discriminator.train()

        # NOTE: no accuracy-based early stop is implemented here; the loop
        # ends only through the learning-rate or max-epoch conditions.
        epoch_i += 1