Example #1
def train(args, controller, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches (gradient accumulation); use the
    # per-epoch value if one is configured, otherwise fall back to the last.
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )

    itr = iterators.GroupedIterator(itr, update_freq)

    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    loop = enumerate(progress, start=epoch_itr.iterations_in_epoch)

    for i, samples in loop:
        log_output = controller.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(controller)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # already included in stats by get_training_stats
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second and updates-per-second calculation
        if i == 0:
            controller.get_meter('wps').reset()
            controller.get_meter('ups').reset()

        num_updates = controller.get_num_updates()

        if num_updates >= max_update:
            break
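Note: get_training_stats is not defined in this snippet. In fairseq-style trainers it gathers the controller's meters into an ordered dict of loggable stats; a trimmed sketch under that assumption (the meter names follow Example #2 below, and the helper itself is hypothetical):

import collections

def get_training_stats(controller):
    # Hypothetical helper: collect a few core meters into printable stats.
    stats = collections.OrderedDict()
    stats['loss'] = controller.get_meter('train_loss').avg
    stats['nll_loss'] = controller.get_meter('train_nll_loss').avg
    stats['wps'] = round(controller.get_meter('wps').avg)
    stats['ups'] = round(controller.get_meter('ups').avg, 1)
    stats['num_updates'] = controller.get_num_updates()
    return stats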
Example #2
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
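The meter classes themselves are defined elsewhere. A minimal sketch of the three types, assuming the fairseq-style interface these examples rely on (update(val, n=1), an .avg property, reset()):

import time

class AverageMeter:
    """Running (weighted) average of a scalar."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class TimeMeter:
    """Average rate of some quantity per second since the last reset."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def avg(self):
        return self.n / (time.time() - self.start)


class StopwatchMeter:
    """Total wall time accumulated over start/stop intervals."""
    def __init__(self):
        self.sum = 0
        self.start_time = None

    def start(self):
        self.start_time = time.time()

    def stop(self):
        if self.start_time is not None:
            self.sum += time.time() - self.start_time
            self.start_time = None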
Example #3
def main():
    use_cuda = args.use_cuda

    train_data = UnlabeledContact(data=args.data_dir)
    print('Number of samples: {}'.format(len(train_data)))
    trainloader = DataLoader(train_data, batch_size=args.batch_size)

    # Contact matrices are 21x21
    input_size = 441
    img_height = 21
    img_width = 21

    vae = AutoEncoder(code_size=20,
                      imgsize=input_size,
                      height=img_height,
                      width=img_width)
    criterion = nn.BCEWithLogitsLoss()

    if use_cuda:
        #vae = nn.DataParallel(vae)
        vae = vae.cuda()  #.half()
        criterion = criterion.cuda()

    optimizer = optim.SGD(vae.parameters(), lr=0.01)

    clock = AverageMeter(name='clock32single', rank=0)
    epoch_loss = 0
    total_loss = 0
    end = time.time()
    for epoch in range(15):
        for batch_idx, data in enumerate(trainloader):
            inputs = data['cont_matrix']
            inputs = inputs.resize_(args.batch_size, 1, 21, 21)
            inputs = inputs.float()
            if use_cuda:
                inputs = inputs.cuda()  #.half()
            inputs = Variable(inputs)
            optimizer.zero_grad()
            output, code = vae(inputs)
            loss = criterion(output, inputs)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # loss.data[0] was the pre-0.4 PyTorch idiom

            clock.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_interval == 0:
                # data is a dict, so len(data) counted keys; use the batch size
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * inputs.size(0), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))

    clock.save(
        path='/home/ygx/libraries/mds/molecules/molecules/conv_autoencoder/runtimes')
Example #4
def main():
    use_cuda = args.use_cuda

    train_data = UnlabeledContact(data=args.data_dir)
    print('Number of samples: {}'.format(len(train_data)))
    trainloader = DataLoader(train_data, batch_size=args.batch_size)

    # Contact matrices are 21x21
    input_size = 441

    encoder = Encoder(input_size=input_size, latent_size=3)
    decoder = Decoder(latent_size=3, output_size=input_size)
    vae = VAE(encoder, decoder, use_cuda=use_cuda)
    criterion = nn.MSELoss()

    if use_cuda:
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)
        encoder = encoder.cuda().half()
        decoder = decoder.cuda().half()
        vae = nn.DataParallel(vae)
        vae = vae.cuda().half()
        criterion = criterion.cuda().half()

    optimizer = optim.SGD(vae.parameters(), lr=0.01)

    clock = AverageMeter(name='clock16', rank=0)
    epoch_loss = 0
    total_loss = 0
    end = time.time()
    for epoch in range(15):
        for batch_idx, data in enumerate(trainloader):
            inputs = data['cont_matrix']
            #           inputs = inputs.resize_(args.batch_size, 1, 21, 21)
            inputs = inputs.float()
            if use_cuda:
                inputs = inputs.cuda().half()
            inputs = Variable(inputs)
            optimizer.zero_grad()
            dec = vae(inputs)
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # loss.data[0] was the pre-0.4 PyTorch idiom

            clock.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_interval == 0:
                # data is a dict, so len(data) counted keys; use the batch size
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * inputs.size(0), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))

    clock.save(path='/home/ygx/libraries/mds/molecules/molecules/linear_vae')
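latent_loss is imported from elsewhere; given the vae.z_mean / vae.z_sigma attributes used above, it presumably computes the standard Gaussian KL term of the VAE objective. A sketch under that assumption:

import torch

def latent_loss(z_mean, z_sigma):
    # KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1),
    # averaged over the batch and latent dimensions.
    mean_sq = z_mean * z_mean
    sigma_sq = z_sigma * z_sigma
    return 0.5 * torch.mean(mean_sq + sigma_sq - torch.log(sigma_sq) - 1)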
Example #5
    def __init__(self,
                 args,
                 model,
                 criterion,
                 optimizer=None,
                 ae_criterion=None):
        self.args = args

        # copy model and criterion on current device
        self.model = model.to(self.args.device)
        self.criterion = criterion.to(self.args.device)
        self.ae_criterion = (ae_criterion.to(self.args.device)
                             if ae_criterion is not None else None)  # default is None
        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        self.meters['wall'] = TimeMeter()  # wall time in seconds

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        if optimizer is not None:
            self._optimizer = optimizer

        self.total_loss = 0.0
        self.train_score = 0.0
        self.total_norm = 0.0
        self.count_norm = 0.0
Example #6
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args.decoder_embed_dim,
                                  args.discriminator_hidden_size,
                                  args.discriminator_linear_size,
                                  args.discriminator_lin_dropout,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        tf_loss = torch.log(tf_scores + 1e-6) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-6) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/professorjp'):
        os.makedirs('checkpoints/professorjp')
    checkpoints_path = 'checkpoints/professorjp/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    # d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    # for p in discriminator.embed_src_tokens.parameters():
    #     p.requires_grad = False
    # for p in discriminator.embed_trg_tokens.parameters():
    #     p.requires_grad = False

    # define optimizer (resolve the optimizer class by name, e.g. "Adam")
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)

    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator
            # print("Policy Gradient Training")
            sys_out_batch_PG, p_PG, hidden_list_PG = generator(
                'PG', epoch_i, sample)  # 64 X 50 X 6632

            out_batch_PG = sys_out_batch_PG.contiguous().view(
                -1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632

            _, prediction = out_batch_PG.topk(1)
            # prediction = prediction.squeeze(1)  # 64*50 = 3200
            # prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape)  # 64 X 50

            with torch.no_grad():
                reward = discriminator(hidden_list_PG)  # 64 X 1

            train_trg_batch_PG = sample['target']  # 64 x 50

            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG,
                                      reward, use_cuda)
            sample_size_PG = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']  # 64
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(),
                                                  sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
            )
            g_optimizer.zero_grad()
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()

            # print("MLE Training")

            sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator(
                "MLE", epoch_i, sample)

            out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                -1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632

            train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

            sample_size_MLE = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE,
                                                  sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
            )
            g_optimizer.zero_grad()  # clear gradients left over from the PG step
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                # print(p.size())
                if p.requires_grad:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            g_optimizer.step()

            num_update += 1

            # part II: train the discriminator

            d_MLE = discriminator(hidden_list_MLE)
            d_PG = discriminator(hidden_list_PG)
            d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
            logging.debug(f"D training loss {d_loss} at batch {i}")

            d_optimizer.zero_grad()
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch_test, p_test, hidden_list_test = generator(
                    'test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(
                    -1, sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test,
                                                      sample_size_test)
                logging.debug(
                    f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                )

                # # discriminator validation
                # bsz = sample['target'].size(0)
                # src_sentence = sample['net_input']['src_tokens']
                # # train with half human-translation and half machine translation
                #
                # true_sentence = sample['target']
                # true_labels = torch.ones(sample['target'].size(0)).float()
                #
                # with torch.no_grad():
                #     sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample)
                #
                # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 X 50) X 6632
                #
                # _, prediction = out_batch.topk(1)
                # prediction = prediction.squeeze(1)  # 64 * 50 = 6632
                #
                # fake_labels = torch.zeros(sample['target'].size(0)).float()
                #
                # fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                #
                # if use_cuda:
                #     fake_labels = fake_labels.cuda()
                #
                # disc_out = discriminator(src_sentence, fake_sentence)
                # d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # d_logging_meters['valid_acc'].update(acc)
                # d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug(
                #     f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

        torch.save(
            generator,
            open(
                checkpoints_path +
                f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                'wb'),
            pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
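update_learning_rate is not shown. Its call, update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer), suggests a step-decay schedule keyed on the update count; a hypothetical sketch:

def update_learning_rate(num_update, threshold, base_lr, lr_shrink, optimizer):
    # Hypothetical: decay the base lr by `lr_shrink` each time `num_update`
    # crosses another multiple of `threshold` updates.
    lr = base_lr * (lr_shrink ** int(num_update // threshold))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr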
Example #7
    def create_meters(self):
        g_logging_meters = OrderedDict()
        g_logging_meters['train_loss'] = AverageMeter()
        g_logging_meters['valid_loss'] = AverageMeter()
        g_logging_meters['train_acc'] = AverageMeter()
        g_logging_meters['valid_acc'] = AverageMeter()
        g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

        d_logging_meters = OrderedDict()
        d_logging_meters['train_loss'] = AverageMeter()
        d_logging_meters['valid_loss'] = AverageMeter()
        d_logging_meters['train_acc'] = AverageMeter()
        d_logging_meters['valid_acc'] = AverageMeter()
        d_logging_meters['train_bleu'] = AverageMeter()
        d_logging_meters['valid_bleu'] = AverageMeter()
        d_logging_meters['train_rouge'] = AverageMeter()
        d_logging_meters['valid_rouge'] = AverageMeter()
        d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

        self.g_logging_meters = g_logging_meters
        self.d_logging_meters = d_logging_meters
Example #8
    writer.add_scalar(tag="lr", scalar_value=optimiser.param_groups[0]["lr"], global_step=e_i)

    for head_i in range(2):
        head = heads[head_i]
        if head == "A":
            dataloaders = dataloaders_head_A
            epoch_loss = config.epoch_loss_head_A
            epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A
        elif head == "B":
            dataloaders = dataloaders_head_B
            epoch_loss = config.epoch_loss_head_B
            epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B
        else:
            raise NotImplementedError(head)  # NotImplemented is a constant, not an exception

        avg_loss_meter = AverageMeter("avg_loss")
        mi_meter = AverageMeter("standard_mi")

        for head_i_epoch in range(head_epochs[head]):
            sys.stdout.flush()

            with tqdm(enumerate(zip(*dataloaders))) as indicator:
                indicator.set_description(f"Head:{head}")
                for b_i, tup in indicator:
                    optimiser.zero_grad()

                    # one less because this is before sobel
                    with autocast:
                        data, label = zip(*tup)
                        all_imgs = torch.cat([data[0] for _ in range(len(data) - 1)]).cuda()
                        all_imgs_tf = torch.cat([data[i] for i in range(1, len(data))]).cuda()
Example #9
    def run_one_epoch(self, training):
        tic = time.time()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        if training:
            amnt = self.num_train
            dataset = self.train_loader
        else:
            dataset = self.val_loader
            amnt = self.num_valid
        with tqdm(total=amnt) as pbar:
            for i, data in enumerate(dataset):
                x, y = data
                # replicate the target across model heads (classification vs. segmentation shapes)
                if self.classification:
                    # assuming one-hot
                    y = y.view(1, -1).expand(self.model.num_heads, -1)
                else:
                    y = y.view(1, -1, 1, x.shape[-2], x.shape[-1]).expand(self.model.num_heads, -1, -1, -1, -1)
                if self.config.use_gpu:
                    x, y = x.cuda(), y.cuda()
                output = self.model(x)
                if training:
                    self.optimizer.zero_grad()
                loss = None

                for head in range(self.model.num_heads):
                    if loss is None:
                        loss = self.criterion(output[head], y[head])
                    else:
                        loss = loss + self.criterion(output[head], y[head])
                loss = loss / self.model.num_heads
                if training:
                    loss.backward()
                    self.optimizer.step()
                try:
                    loss_data = loss.data[0]
                except IndexError:
                    loss_data = loss.data.item()
                losses.update(loss_data)
                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)
                if self.classification:
                    _, predicted = torch.max(output.data, -1)
                    total = self.batch_size*self.model.num_heads
                    correct = (predicted == y).sum().item()
                    acc = correct/total
                    accs.update(acc)
                    pbar.set_description(f"{(toc - tic):.1f}s - loss: {loss_data:.3f} acc {accs.avg:.3f}")
                else:
                    pbar.set_description(f"{(toc - tic):.1f}s - loss: {loss_data:.3f}")
                pbar.update(self.batch_size)
                if training and i % 2 == 0:
                    self.model.log_illumination(self.curr_epoch, i)
                if not training and i == 0 and not self.classification:
                    y_sample = y[0, 0].view(256, 256).detach().cpu().numpy()
                    p_sample = output[0, 0].view(256, 256).detach().cpu().numpy()
                    wandb.log({f"images_epoch{self.curr_epoch}": [
                        wandb.Image(np.round(p_sample * 255), caption="prediction"),
                        wandb.Image(np.round(y_sample * 255), caption="label")]}, step=self.curr_epoch)
        return losses.avg, accs.avg
Example #10
File: main.py Project: yngtodd/koda
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--resume',
                        action='store_true',  # type=bool parses any non-empty string as True
                        default=False,
                        help='Resumes training from savefile.')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307, ), (0.3081, ))
                       ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307, ), (0.3081, ))
                       ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    encoder = Encoder2()

    savefile = './savepoints/checkpoint10.pth.tar'

    if args.resume:
        if os.path.isfile(savefile):
            print("=> loading checkpoint '{}'".format(savefile))
            checkpoint = torch.load(savefile)
            encoder.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(savefile))
        else:
            print("=> no checkpoint found at '{}'".format(savefile))

    model = TransferNet(encoder).to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    train_meter = AverageMeter(name='trainacc')
    test_meter = AverageMeter(name='testacc')

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader, test_meter)

    test_meter.save('./')
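The AverageMeter here (also in Example #3) differs from the fairseq-style one sketched after Example #2: it takes a name and rank and can persist itself with save(path). A guess at that variant, purely illustrative:

import os
import numpy as np

class AverageMeter:
    """Running average that also records its history for saving to disk."""
    def __init__(self, name, rank=0):
        self.name = name
        self.rank = rank
        self.values = []
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.values.append(val)
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def save(self, path):
        # one file per meter and rank, e.g. ./trainacc_rank0.npy
        out = os.path.join(path, '{}_rank{}.npy'.format(self.name, self.rank))
        np.save(out, np.array(self.values))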
Example #11
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer (resolve the optimizer class by name, e.g. "Adam")
    g_optimizer = getattr(torch.optim, args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)

    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use the policy gradient method to train the generator

            # train with policy gradient half the time, MLE otherwise
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
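PGLoss is imported from elsewhere. Consistent with the call pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda), a plausible sketch is a reward-weighted negative log-likelihood over the sampled translation (an assumption, not the repository's actual code):

import torch
import torch.nn as nn

class PGLoss(nn.Module):
    """Policy-gradient loss: per-sentence reward times the tokens' NLL."""
    def __init__(self, ignore_index=None, size_average=True, reduce=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, logprobs, target, reward, use_cuda):
        # logprobs: (batch, seq, vocab) log-probabilities from the generator
        # target:   (batch, seq) token ids; reward: (batch, 1) discriminator scores
        # use_cuda is accepted only for interface parity with the call site.
        batch, seq, _ = logprobs.size()
        flat = logprobs.view(-1, logprobs.size(-1))
        nll = -flat.gather(1, target.view(-1, 1)).squeeze(1).view(batch, seq)
        nll = nll * reward  # broadcast the sentence-level reward over tokens
        if self.ignore_index is not None:
            nll = nll * target.ne(self.ignore_index).float()  # mask padding
        loss = nll.sum()
        return loss / batch if self.size_average else loss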
Example #12
def forward(data_loader,
            model,
            criterion,
            epoch,
            training,
            model_type,
            optimizer=None,
            writer=None):
    if training:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()

    total_steps = len(data_loader)

    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        inputs = inputs.to('cuda:0')
        target = target.to('cuda:0')

        # compute output
        output = model(inputs)
        if model_type == 'int':
            # unpack the integer output and its scale exponent
            output, output_exp = output
            output = output.float()
            loss = criterion(output * (2**output_exp.float()), target)
        else:
            output_exp = 0
            loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(float(loss), inputs.size(0))
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        top1.update(float(prec1), inputs.size(0))
        top5.update(float(prec5), inputs.size(0))

        if training:
            if model_type == 'int':
                model.backward(target)

            elif model_type == 'hybrid':
                # float backward
                optimizer.update(epoch, epoch * len(data_loader) + i)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # int8 backward
                model.backward()
            else:
                optimizer.update(epoch, epoch * len(data_loader) + i)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0 and training:
            logging.info('{model_type} [{0}][{1}/{2}] '
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                         'Data {data_time.val:.2f} '
                         'loss {loss.val:.3f} ({loss.avg:.3f}) '
                         'e {output_exp:d} '
                         '@1 {top1.val:.3f} ({top1.avg:.3f}) '
                         '@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch,
                             i,
                             len(data_loader),
                             model_type=model_type,
                             batch_time=batch_time,
                             data_time=data_time,
                             loss=losses,
                             output_exp=output_exp,
                             top1=top1,
                             top5=top5))

            if args.grad_hist:
                if args.model_type == 'int':
                    for idx, l in enumerate(model.forward_layers):
                        if hasattr(l, 'weight'):
                            grad = l.grad_int32acc
                            writer.add_histogram(
                                'Grad/' + l.__class__.__name__ + '_' +
                                str(idx), grad, epoch * total_steps + i)

                elif args.model_type == 'float':
                    for idx, l in enumerate(model.layers):
                        if hasattr(l, 'weight'):
                            writer.add_histogram(
                                'Grad/' + l.__class__.__name__ + '_' +
                                str(idx), l.weight.grad,
                                epoch * total_steps + i)
                    for idx, l in enumerate(model.classifier):
                        if hasattr(l, 'weight'):
                            writer.add_histogram(
                                'Grad/' + l.__class__.__name__ + '_' +
                                str(idx), l.weight.grad,
                                epoch * total_steps + i)

    return losses.avg, top1.avg, top5.avg
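accuracy follows the usual top-k helper from the PyTorch ImageNet example family; a sketch, assuming that convention:

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res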
Example #13
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Build model
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)

    if use_cuda:
        generator.cuda()
    else:
        generator.cpu()

    optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(),
                                                      args.learning_rate)
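    # note: getattr(torch.optim, args.optimizer) performs the same lookup
    # without resorting to eval()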

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf

    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']
    # main training loop

    # added to write out the training loss
    f1 = open("train_loss", "a")

    while lr > args.min_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (min(args.max_source_positions,
                                   generator.encoder.max_positions()),
                               min(args.max_target_positions,
                                   generator.decoder.max_positions()))

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )
        # set training mode
        generator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            loss = generator(sample)
            sample_size = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.data / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            f1.write("{0}\n".format(logging_meters['train_loss'].avg))
            logging.debug(
                "loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))
            optimizer.zero_grad()
            loss.backward()

            # all-reduce grads and rescale by grad_denom
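            # (dividing the accumulated gradients by sample_size gives the same
            # update as averaging the summed loss over tokens, or over
            # sentences when args.sentence_avg is set, before backward())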
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            optimizer.step()

        # validation -- this is a crude estimation because there might be some padding at the end
        max_positions_valid = (
            generator.encoder.max_positions(),
            generator.decoder.max_positions(),
        )

        # Initialize dataloader
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=args.
            skip_invalid_size_inputs_valid_test,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )
        # set validation mode
        generator.eval()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)
                loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("dev loss at batch {0}: {1:.3f}".format(
                    i, logging_meters['valid_loss'].avg))

        # update learning rate
        lr_scheduler.step(logging_meters['valid_loss'].avg)
        lr = optimizer.param_groups[0]['lr']

        logging.info(
            "Average loss value per instance is {0} at the end of epoch {1}".
            format(logging_meters['valid_loss'].avg, epoch_i))
        torch.save(
            generator.state_dict(),
            open(
                args.model_file + "data.nll_{0:.3f}.epoch_{1}.pt".format(
                    logging_meters['valid_loss'].avg, epoch_i), 'wb'))

        if logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = logging_meters['valid_loss'].avg
            torch.save(generator.state_dict(),
                       open(args.model_file + "best_gmodel.pt", 'wb'))

        epoch_i += 1

    f1.close()
def train_d(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (torch.cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/discriminator'):
        os.makedirs('checkpoints/discriminator')

    checkpoints_path = 'checkpoints/discriminator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)

    # Load generator
    assert os.path.exists('checkpoints/generator/best_gmodel.pt')
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt')
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            # generator = torch.nn.DataParallel(generator).cuda()
            generator.cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()),
    #                                                     args.d_learning_rate, momentum=args.momentum, nesterov=True)

    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the accuracy reaches the target value
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    trg_acc = 0.82
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # prepare training and validation data (the validation set is built only once)
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i,
                                  use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i,
                                  use_cuda)
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator,
                                          epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train,
                                           maxlen=args.fixed_max_len)

        # discriminator training dataloader
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)

        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # set training mode
        discriminator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            disc_out = discriminator(sample['src_tokens'],
                                     sample['trg_tokens'])

            loss = criterion(disc_out, sample['labels'])
            _, prediction = F.softmax(disc_out, dim=1).topk(1)
            acc = torch.sum(
                prediction == sample['labels'].unsqueeze(1)).float() / len(
                    sample['labels'])

            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                          format(logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg,
                                 optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()

        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)

                disc_out = discriminator(sample['src_tokens'],
                                         sample['trg_tokens'])

                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(
                    prediction == sample['labels'].unsqueeze(1)).float() / len(
                        sample['labels'])

                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                              format(logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg,
                                     optimizer.param_groups[0]['lr'], i))

            del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \
                       .format(logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i))

            if logging_meters['valid_loss'].avg < best_dev_loss:
                best_dev_loss = logging_meters['valid_loss'].avg
                torch.save(discriminator.state_dict(),
                           checkpoints_path + "best_dmodel.pt")

        # stop pretraining once the discriminator reaches the target accuracy (82%)
        if logging_meters['valid_acc'].avg >= trg_acc:
            return

        epoch_i += 1
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['MLE_train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['PG_train_loss'] = AverageMeter()
    g_logging_meters['MLE_train_acc'] = AverageMeter()
    g_logging_meters['PG_train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['D_h_train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['D_s_train_loss'] = AverageMeter()
    d_logging_meters['D_h_train_acc'] = AverageMeter()
    d_logging_meters['D_s_train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.1
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.1
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim,
                                      args.discriminator_hidden_size,
                                      args.discriminator_linear_size,
                                      args.discriminator_lin_dropout,
                                      use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args,
                                      dataset.src_dict,
                                      dataset.dst_dict,
                                      use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    # Load generator model
    g_model_path = 'checkpoints/zhenwarm/genev.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")

    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator_s.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator_s.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    # # Load discriminatorH model
    # d_H_model_path = 'checkpoints/joint_warm/DH.pt'
    # assert os.path.exists(d_H_model_path)
    # # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    # d_H_model_dict = discriminator_h.state_dict()
    # d_H_model = torch.load(d_H_model_path)
    # d_H_pretrained_dict = d_H_model.state_dict()
    # # 1. filter out unnecessary keys
    # d_H_pretrained_dict = {k: v for k, v in d_H_pretrained_dict.items() if k in d_H_model_dict}
    # # 2. overwrite entries in the existing state dict
    # d_H_model_dict.update(d_H_pretrained_dict)
    # # 3. load the new state dict
    # discriminator_h.load_state_dict(d_H_model_dict)
    # print("pre-trained Discriminator_H loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
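        # standard non-saturating GAN discriminator loss:
        # -log D(teacher_forced) - log(1 - D(autoregressive)), with 1e-9
        # inside the logs for numerical stability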
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/realmyzhenup10shrink07drop01new'):
        os.makedirs('checkpoints/realmyzhenup10shrink07drop01new')
    checkpoints_path = 'checkpoints/realmyzhenup10shrink07drop01new/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator_s word embedding (as Wu et al. do)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer_h = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_h.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    d_optimizer_s = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator_s.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # training mode is set per batch at the top of the loop below

        for i, sample in enumerate(trainloader):
            generator.train()
            discriminator_h.train()
            discriminator_s.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)
            train_MLE = 0
            train_PG = 0

            if random.random() > 0.5 and i != 0:
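                # this branch (taken with probability ~0.5, never on the first
                # batch) does a teacher-forcing MLE step; the else branch does
                # a policy-gradient step rewarded by discriminator_s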
                train_MLE = 1
                # print("MLE Training")
                sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator(
                    "MLE", epoch_i, sample)

                out_batch_MLE = sys_out_batch_MLE.contiguous().view(
                    -1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632

                train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
                loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

                sample_size_MLE = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(
                    2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['MLE_train_loss'].update(
                    logging_loss_MLE, sample_size_MLE)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['MLE_train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss_MLE.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size_MLE)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            else:
                train_PG = 1
                ## part I: use the policy gradient method to train the generator
                # print("Policy Gradient Training")
                sys_out_batch_PG, p_PG, hidden_list_PG = generator(
                    'PG', epoch_i, sample)  # 64 X 50 X 6632

                out_batch_PG = sys_out_batch_PG.contiguous().view(
                    -1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch_PG.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50
                # if d_logging_meters['train_acc'].avg >= 0.75:
                with torch.no_grad():
                    reward = discriminator_s(sample['net_input']['src_tokens'],
                                             prediction)  # 64 X 1
                # else:
                #     reward = torch.ones(args.joint_batch_size)
                train_trg_batch_PG = sample['target']  # 64 x 50

                pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG,
                                          reward, use_cuda)
                sample_size_PG = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss_PG = pg_loss_PG / math.log(2)
                g_logging_meters['PG_train_loss'].update(
                    logging_loss_PG.item(), sample_size_PG)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                if g_logging_meters['PG_train_loss'].val < 0.5:
                    pg_loss_PG.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                                   args.clip_norm)
                    g_optimizer.step()

            num_update += 1

            # if g_logging_meters["MLE_train_loss"].avg < 4:
            #  part II: train the discriminator
            # discriminator_h
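            # both discriminators are updated only once every 10 generator updates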
            if num_update % 10 == 0:
                if g_logging_meters["PG_train_loss"].val < 2:
                    assert (train_MLE == 1) != (train_PG == 1)
                    if train_MLE == 1:
                        d_MLE = discriminator_h(hidden_list_MLE.detach())
                        M_loss = torch.log(d_MLE + 1e-9) * (-1)
                        h_d_loss = M_loss.sum()
                    elif train_PG == 1:
                        d_PG = discriminator_h(hidden_list_PG.detach())
                        P_loss = torch.log(1 - d_PG + 1e-9) * (-1)
                        h_d_loss = P_loss.sum()

                    # d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
                    logging.debug(f"D_h training loss {h_d_loss} at batch {i}")

                    d_optimizer_h.zero_grad()

                    h_d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        discriminator_h.parameters(), args.clip_norm)
                    d_optimizer_h.step()

                #discriminator_s
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len, i.e. 64 X 50

                # now train with machine translation output i.e generator output

                with torch.no_grad():
                    sys_out_batch, p, hidden_list = generator(
                        'PG', epoch_i, sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                fake_labels = torch.zeros(
                    sample['target'].size(0)).float()  # 64 length vector
                # true_labels = torch.ones(sample['target'].size(0)).float()  # 64 length vector
                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50

                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    # true_labels = true_labels.cuda()

                # if random.random() > 0.5:
                fake_disc_out = discriminator_s(src_sentence,
                                                fake_sentence)  # 64 X 1
                fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                          fake_labels)
                acc = torch.sum(
                    torch.round(fake_disc_out).squeeze(1) ==
                    fake_labels).float() / len(fake_labels)
                d_loss = fake_d_loss
                # else:
                #
                #     true_sentence = sample['target'].view(-1)  # 64*50 = 3200
                #
                #     true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                #     true_disc_out = discriminator_s(src_sentence, true_sentence)
                #     acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                #     true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #     d_loss = true_d_loss

                # acc_fake = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # acc_true = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (acc_fake + acc_true) / 2
                # acc = acc_fake
                # d_loss = fake_d_loss + true_d_loss
                s_d_loss = d_loss
                d_logging_meters['D_s_train_acc'].update(acc.item())
                d_logging_meters['D_s_train_loss'].update(s_d_loss.item())
                logging.debug(
                    f"D_s training loss {d_logging_meters['D_s_train_loss'].avg:.3f}, acc {d_logging_meters['D_s_train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer_s.zero_grad()
                s_d_loss.backward()
                d_optimizer_s.step()

            if num_update % 10000 == 0:
                # validation
                # set validation mode
                print(
                    'validation and save+++++++++++++++++++++++++++++++++++++++++++++++'
                )
                generator.eval()
                discriminator_h.eval()
                discriminator_s.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=use_cuda)

                        # generator validation
                        sys_out_batch_test, p_test, hidden_list_test = generator(
                            'test', epoch_i, sample)
                        out_batch_test = sys_out_batch_test.contiguous().view(
                            -1,
                            sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss_test = g_criterion(out_batch_test, dev_trg_batch)
                        sample_size_test = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss_test = loss_test / sample_size_test / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss_test, sample_size_test)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}w.sampling_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
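
The filter/update/load recipe that recurs in these examples (the numbered steps 1-3 above) can be factored into a helper; a minimal sketch, where `load_partial_state_dict` is a name introduced here for illustration and shape mismatches are not checked:

import torch
from torch import nn

def load_partial_state_dict(model, pretrained_state):
    """Copy only the keys that also exist in model's own state dict."""
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    compatible = {k: v for k, v in pretrained_state.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(compatible)
    # 3. load the merged state dict
    model.load_state_dict(model_dict)

# toy usage: the source model's extra head is simply ignored
src = nn.ModuleDict({"encoder": nn.Linear(4, 4), "head": nn.Linear(4, 2)})
dst = nn.ModuleDict({"encoder": nn.Linear(4, 4)})
load_partial_state_dict(dst, src.state_dict())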
Exemplo n.º 16
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    #print(f"First dict: {pretrained_dict}")
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    #print(f"Second dict: {pretrained_dict}")
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    #print(f"model dict: {model_dict}")
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    #return
    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            ## part I: use the policy gradient method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    #print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)  # dataset.dst_dict.pad())
                train_trg_batch = sample['target']
                #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle validation: extra teacher-forcing pass to log the MLE loss
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                #print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence,
                                     trg_sentence)  #, dataset.dst_dict.pad())
            #print(f"disc out: {disc_out.shape}, labels: {labels.shape}")
            #print(f"labels: {labels}")
            # BCEWithLogitsLoss expects float targets, not long
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(
                torch.sigmoid(disc_out).round() == labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc.item())
            d_logging_meters['train_loss'].update(d_loss.item())
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg,
            #                                                                            d_logging_meters['train_acc'].avg,
            #                                                                            i))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # train with half human-translation and half machine translation

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence)
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc.item())
                d_logging_meters['valid_loss'].update(d_loss.item())
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg,
                #                                                                     d_logging_meters['valid_acc'].avg, i))

        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
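
The half-real/half-fake discriminator batches above can be exercised in isolation; a minimal sketch with toy tensors (all shapes are placeholders, and random logits stand in for discriminator(src_sentence, trg_sentence)):

import numpy as np
import torch

bsz, seq_len = 4, 6
true_sentence = torch.randint(1, 100, (bsz, seq_len))
fake_sentence = torch.randint(1, 100, (bsz, seq_len))
true_labels = torch.ones(bsz)
fake_labels = torch.zeros(bsz)

# stack real and fake pairs, shuffle, and keep one batch worth, so every
# discriminator batch mixes human and machine translations
trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
labels = torch.cat([true_labels, fake_labels], dim=0)
indices = torch.from_numpy(np.random.permutation(2 * bsz))
trg_sentence = trg_sentence[indices][:bsz]
labels = labels[indices][:bsz]

# accuracy from raw logits: sigmoid, round to {0, 1}, compare with the labels
disc_out = torch.randn(bsz)
acc = torch.sum(torch.sigmoid(disc_out).round() == labels).float() / len(labels)
print(acc.item())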
Exemplo n.º 17
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # check checkpoints saving path
    if not os.path.exists('checkpoints/generator'):
        os.makedirs('checkpoints/generator')

    checkpoints_path = 'checkpoints/generator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch
    logging_meters['update_times'] = AverageMeter()

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Build model
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)

    # g_model_path = 'checkpoints/generator/numupdate1.4180668458302803.data.nll_105000.000.pt'
    # assert os.path.exists(g_model_path)
    # # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    # model_dict = generator.state_dict()
    # model = torch.load(g_model_path)
    # pretrained_dict = model.state_dict()
    # # 1. filter out unnecessary keys
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # # 2. overwrite entries in the existing state dict
    # model_dict.update(pretrained_dict)
    # # 3. load the new state dict
    # generator.load_state_dict(model_dict)
    # print("pre-trained Generator loaded successfully!")

    if use_cuda:
        if len(args.gpuid) > 1:
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
    else:
        generator.cpu()

    print("Training generator...")

    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')

    optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(),
                                                      args.learning_rate)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf

    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']
    num_update = 0
    # main training loop
    while lr > args.min_g_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (min(args.max_source_positions,
                                   generator.encoder.max_positions()),
                               min(args.max_target_positions,
                                   generator.decoder.max_positions()))

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # training mode is set per batch inside the loop below

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            generator.train()
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            sys_out_batch = generator(sample)
            out_batch = sys_out_batch.contiguous().view(
                -1, sys_out_batch.size(-1))
            train_trg_batch = sample['target'].view(-1)
            loss = g_criterion(out_batch, train_trg_batch)
            sample_size = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.item() / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            logging.debug(
                "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))
            optimizer.zero_grad()
            loss.backward()

            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            num_update = num_update + 1
            if num_update % 5000 == 0:
                # validation -- this is a crude estimation because there might be some padding at the end
                max_positions_valid = (
                    generator.encoder.max_positions(),
                    generator.decoder.max_positions(),
                )

                # Initialize dataloader
                val_itr = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.max_sentences,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
                    descending=True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )
                # set validation mode
                generator.eval()

                # reset meters
                for key, val in logging_meters.items():
                    if val is not None:
                        val.reset()
                with torch.no_grad():
                    for j, sample in enumerate(val_itr):
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=use_cuda)
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))

                        val_trg_batch = sample['target'].view(-1)
                        loss = g_criterion(out_batch, val_trg_batch)

                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss.item() / sample_size / math.log(2)
                        logging_meters['valid_loss'].update(loss, sample_size)
                        logging.debug(
                            "g dev loss at batch {0}: {1:.3f}".format(
                                j, logging_meters['valid_loss'].avg))

                # update learning rate
                lr_scheduler.step(logging_meters['valid_loss'].avg)
                lr = optimizer.param_groups[0]['lr']

                logging.info(
                    "Average g dev loss per instance is {0} at update {1} of epoch {2}"
                    .format(logging_meters['valid_loss'].avg, num_update,
                            epoch_i))
                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        "numupdate{0}.data.nll_{1:.1f}.pt".format(
                            num_update, logging_meters['valid_loss'].avg),
                        'wb'))

                # if logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = logging_meters['valid_loss'].avg
                #     torch.save(generator.state_dict(), open(
                #         checkpoints_path + "best_gmodel.pt", 'wb'))

        epoch_i += 1
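

# The loop above interleaves a dev-set pass every 5000 updates with
# ReduceLROnPlateau, stopping once the learning rate falls below
# args.min_g_lr. A minimal self-contained sketch of that stopping pattern;
# `train_one_epoch` and `evaluate` are hypothetical stand-ins, not functions
# from this codebase:
import torch


def train_until_lr_floor(model, optimizer, train_one_epoch, evaluate,
                         min_lr, max_epoch=float('inf')):
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=0.5)
    epoch = 1
    while optimizer.param_groups[0]['lr'] > min_lr and epoch <= max_epoch:
        train_one_epoch(model, optimizer)
        scheduler.step(evaluate(model))  # shrink lr when dev loss plateaus
        epoch += 1
    return model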
Exemplo n.º 18
0
def train_g(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/generator'):
        os.makedirs('checkpoints/generator')

    checkpoints_path = 'checkpoints/generator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['bsz'] = AverageMeter()  # sentences per batch
    logging_meters['update_times'] = AverageMeter()

    # Build model
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)

    if use_cuda:
        if len(args.gpuid) > 1:
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
    else:
        generator.cpu()

    optimizer = getattr(torch.optim, args.optimizer)(generator.parameters(),
                                                     args.learning_rate)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf

    epoch_i = 1
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']
    # main training loop
    while lr > args.min_g_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (min(args.max_source_positions,
                                   generator.encoder.max_positions()),
                               min(args.max_target_positions,
                                   generator.decoder.max_positions()))

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_train,
            seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )
        # set training mode
        generator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            loss = generator(sample)
            sample_size = sample['target'].size(
                0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss = loss.item() / sample_size / math.log(2)
            logging_meters['bsz'].update(nsentences)
            logging_meters['train_loss'].update(logging_loss, sample_size)
            logging.debug(
                "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format(
                    i, logging_meters['train_loss'].avg,
                    round(logging_meters['bsz'].avg),
                    optimizer.param_groups[0]['lr']))
            optimizer.zero_grad()
            loss.backward()

            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                if p.requires_grad:
                    p.grad.data.div_(sample_size)

            torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                           args.clip_norm)
            optimizer.step()

        # validation -- this is a crude estimation because there might be some padding at the end
        max_positions_valid = (
            generator.encoder.max_positions(),
            generator.decoder.max_positions(),
        )

        # Initialize dataloader
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )
        # set validation mode
        generator.eval()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()
        with torch.no_grad():
            for i, sample in enumerate(itr):
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)
                loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss.item() / sample_size / math.log(2)
                logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("g dev loss at batch {0}: {1:.3f}".format(
                    i, logging_meters['valid_loss'].avg))

        # update learning rate
        lr_scheduler.step(logging_meters['valid_loss'].avg)
        lr = optimizer.param_groups[0]['lr']

        logging.info(
            "Average g loss value per instance is {0} at the end of epoch {1}".
            format(logging_meters['valid_loss'].avg, epoch_i))
        torch.save(
            generator.state_dict(),
            open(
                checkpoints_path + "data.nll_{0:.3f}.epoch_{1}.pt".format(
                    logging_meters['valid_loss'].avg, epoch_i), 'wb'))

        if logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = logging_meters['valid_loss'].avg
            torch.save(generator.state_dict(),
                       open(checkpoints_path + "best_gmodel.pt", 'wb'))

        epoch_i += 1
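

# Both loops above log `loss.item() / sample_size / math.log(2)`: the summed
# natural-log NLL converted to average bits per token (or per sentence when
# --sentence-avg is set). A tiny helper making that conversion explicit:
import math


def nll_to_bits(summed_nll, sample_size):
    """Average bits per unit from a summed natural-log NLL."""
    return summed_nll / sample_size / math.log(2)

# example: 693.1 nats summed over 1000 tokens is about 1.0 bit per token:
# nll_to_bits(693.1, 1000)  ->  ~1.0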
Exemplo n.º 19
0
    def train(self,
              epoch,
              data_loader,
              opt_sn,
              opt_vn,
              mode,
              writer=None,
              print_freq=1):
        self.sn.train()
        self.vn.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_sn = AverageMeter()
        losses_vn = AverageMeter()
        ious = AverageMeter()

        end = time.time()

        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            img, lbl = self._parse_data(inputs)

            # train sn
            loss_sn, iou_, heat_map = self._forward_sn(img, lbl)
            losses_sn.update(loss_sn.item(), lbl.size(0))
            ious.update(iou_, lbl.size(0))

            if mode == 'sn':
                # if opt_sn is None:
                #     img.volatile = True
                #     lbl.volatile = True
                # else:
                #     img.volatile = False
                #     lbl.volatile = False

                self.step(opt_sn, loss_sn)
            # train vn
            elif mode == 'vn':
                # heat_map = heat_map.detach()
                _, seg_pred = torch.max(heat_map, dim=1, keepdim=True)
                # seg_pred = onehot(seg_pred, 2)
                # heat_map = heat_map
                target_iou = iou(heat_map.data, lbl.data, average=False)

                loss_vn, iou_pred = self._forward_vn(img, heat_map, target_iou)
                losses_vn.update(loss_vn.item(), lbl.size(0))
                self.step(opt_vn, loss_vn)

            # backprop & gradient step (handled by self.step in the branches above)
            # if opt_sn is not None:
            #     self.step(opt_sn, loss_sn)
            # if opt_vn is not None:
            #     self.step(opt_vn, loss_vn)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss_sn {:.3f} ({:.3f})\t'
                      'Loss_vn {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(data_loader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg,
                          losses_sn.val, losses_sn.avg, losses_vn.val,
                          losses_vn.avg, ious.val, ious.avg))

        if writer is not None:
            # note: seg_pred is only defined when mode == 'vn'
            summary_output_lbl(seg_pred.data, lbl.data, writer, epoch)
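
    # `self.step` is called above but not defined in this excerpt; a sketch of
    # the assumed implementation, i.e. the standard update cycle shared by the
    # segmentation network (sn) and the value network (vn):
    def step(self, optimizer, loss):
        optimizer.zero_grad()  # clear gradients from the previous update
        loss.backward()        # backpropagate the current loss
        optimizer.step()       # apply the gradient step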
Exemplo n.º 20
0
def train_model(output_path,
                model,
                dataloaders,
                dataset_sizes,
                criterion,
                optimizer,
                num_epochs=5,
                scheduler=None):
    if not os.path.exists('iterations/' + str(output_path) + '/saved'):
        os.makedirs('iterations/' + str(output_path) + '/saved')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    losses = AverageMeter()
    accuracies = AverageMeter()
    all_preds = []
    all_labels = []
    val_auc_all = []
    val_acc_all = []
    test_auc_all = []
    test_acc_all = []
    TPFPFN0_all = []
    TPFPFN1_all = []
    best_val_auc = 0.0
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print('-' * 50)
        print('Epoch {}/{}'.format(epoch, num_epochs))
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # reset per-phase accumulators so the AUC below reflects only this
            # phase of the current epoch
            all_preds = []
            all_labels = []
            losses.reset()
            accuracies.reset()

            # tqdm_loader = tqdm(dataloaders[phase])
            # for data in tqdm_loader:
            #     inputs, labels = data
            for i, (inputs, labels) in enumerate(dataloaders[phase]):

                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                labels_onehot = torch.nn.functional.one_hot(labels,
                                                            num_classes=2)
                labels_onehot = labels_onehot.type(torch.FloatTensor)

                # BCEloss = torch.nn.functional.binary_cross_entropy_with_logits(outputs.cpu(), labels_onehot, torch.FloatTensor([1.0, 1.0]))
                BCEloss = criterion(outputs.cpu(), labels_onehot)
                # print("BCEloss", BCEloss)
                BCEloss_rank = binary_crossentropy_with_ranking(
                    outputs, labels_onehot)
                # print("BCEloss_rank", BCEloss_rank)
                # BCEloss_rank.requires_grad = True
                loss = BCEloss + 0 * BCEloss_rank  # ranking term currently disabled (zero weight)
                # print("BCEloss, BCEloss_rank", BCEloss, BCEloss_rank)
                # loss = (BCEloss_rank + 1) * BCEloss

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                losses.update(loss.item(), inputs.size(0))
                acc = float(torch.sum(preds == labels.data)) / preds.shape[0]
                accuracies.update(acc)
                all_preds += list(
                    torch.nn.functional.softmax(outputs, dim=1)[:, 1]
                    .cpu().data.numpy())
                all_labels += list(labels.cpu().data.numpy())
                # tqdm_loader.set_postfix(loss=losses.avg, acc=accuracies.avg)

            auc = roc_auc_score(all_labels, all_preds)

            if phase == 'train':
                auc_t = auc
                loss_t = losses.avg
                acc_t = accuracies.avg
            if phase == 'val':
                auc_v = auc
                loss_v = losses.avg
                acc_v = accuracies.avg
                val_acc_all.append(acc_v)
                val_auc_all.append(auc_v)

        print('Train AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format(
            auc_t, loss_t, acc_t))
        print('Val AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format(
            auc_v, loss_v, acc_v))
        if auc_v > best_val_auc:
            best_val_auc = auc_v
            best_epoch = epoch
            # print(auc_v, best_val_auc)
            # print(best_epoch)
            best_model = copy.deepcopy(model)

        # note: model.module assumes the model is wrapped in nn.DataParallel
        torch.save(
            model.module, './iterations/' + str(output_path) +
            '/saved/model_{}_epoch.pt'.format(epoch))
        # ############################################################################################################# Test
        for phase in ['test']:
            model.eval()  # Set model to evaluate mode

            # reset accumulators so test metrics cover test data only; no loss
            # is computed in this phase, so losses.avg carries over from val
            all_preds = []
            all_labels = []
            accuracies.reset()

            for i, (inputs, labels) in enumerate(dataloaders[phase]):

                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(False):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs.data, 1)

                acc = float(torch.sum(preds == labels.data)) / preds.shape[0]
                accuracies.update(acc)
                all_preds += list(
                    torch.nn.functional.softmax(outputs, dim=1)[:, 1]
                    .cpu().data.numpy())
                all_labels += list(labels.cpu().data.numpy())
                # tqdm_loader.set_postfix(loss=losses.avg, acc=accuracies.avg)

            auc = roc_auc_score(all_labels, all_preds)

            auc_test = auc
            loss_test = losses.avg
            acc_test = accuracies.avg
            test_acc_all.append(acc_test)
            test_auc_all.append(auc_test)

        print('Test AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format(
            auc_test, loss_test, acc_test))

        nb_classes = 2
        with torch.no_grad():
            TrueP0 = 0
            FalseP0 = 0
            FalseN0 = 0
            TrueP1 = 0
            FalseP1 = 0
            FalseN1 = 0
            for i, (inputs, classes) in enumerate(dataloaders[phase]):
                # fresh per-batch confusion matrix; per-class TP/FP/FN totals
                # are accumulated across batches below
                confusion_matrix = torch.zeros(nb_classes, nb_classes)
                input = inputs.to(device)
                target = classes.to(device)
                outputs = model(input)
                _, preds = torch.max(outputs, 1)
                for t, p in zip(target.view(-1), preds.view(-1)):
                    confusion_matrix[t, p] += 1
                this_class = 0
                col = confusion_matrix[:, this_class]
                row = confusion_matrix[this_class, :]
                TP = row[this_class]
                FN = sum(row) - TP
                FP = sum(col) - TP
                # print("TP, FP, FN: ", TP, FP, FN)
                TrueP0 = TrueP0 + TP
                FalseP0 = FalseP0 + FP
                FalseN0 = FalseN0 + FN

                this_class = 1
                col = confusion_matrix[:, this_class]
                row = confusion_matrix[this_class, :]
                TP = row[this_class]
                FN = sum(row) - TP
                FP = sum(col) - TP
                # print("TP, FP, FN: ", TP, FP, FN)
                TrueP1 = TrueP1 + TP
                FalseP1 = FalseP1 + FP
                FalseN1 = FalseN1 + FN
            TPFPFN0 = [TrueP0, FalseP0, FalseN0]
            TPFPFN1 = [TrueP1, FalseP1, FalseN1]
            TPFPFN0_all.append(TPFPFN0)
            TPFPFN1_all.append(TPFPFN1)
            print("overall_TP, FP, FN for 0: ", TrueP0, FalseP0, FalseN0)
            print("overall_TP, FP, FN for 1: ", TrueP1, FalseP1, FalseN1)

    print("best_ValidationEpoch:", best_epoch)
    # print(TPFPFN0_all, val_auc_all, test_auc_all)
    TPFPFN0_best = TPFPFN0_all[best_epoch - 1][0]
    TPFPFN1_best = TPFPFN1_all[best_epoch - 1][0]
    val_auc_best = val_auc_all[best_epoch - 1]
    val_acc_best = val_acc_all[best_epoch - 1]
    test_auc_best = test_auc_all[best_epoch - 1]
    test_acc_best = test_acc_all[best_epoch - 1]

    # #################### save only the best, delete others
    file_path = './iterations/' + str(output_path) + '/saved/model_' + str(
        best_epoch) + '_epoch.pt'
    if os.path.isfile(file_path):
        for CleanUp in glob.glob('./iterations/' + str(output_path) +
                                 '/saved/*.pt'):
            if 'model_' + str(best_epoch) + '_epoch.pt' not in CleanUp:
                os.remove(CleanUp)
    # # ######################################################

    return best_epoch, best_model, TPFPFN0_all[best_epoch - 1], TPFPFN1_all[
        best_epoch - 1], test_acc_best, test_auc_best
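

# The per-class bookkeeping above reduces to this small helper (a sketch, not
# part of the original code), for a confusion matrix indexed as
# confusion[target, prediction]:
import torch


def tp_fp_fn(confusion, cls):
    row = confusion[cls, :]  # samples whose true label is cls
    col = confusion[:, cls]  # samples predicted as cls
    tp = row[cls]
    fn = row.sum() - tp  # true cls predicted as something else
    fp = col.sum() - tp  # other classes predicted as cls
    return tp, fp, fn

# example: tp_fp_fn(torch.tensor([[5., 2.], [1., 7.]]), 0) -> (5., 1., 2.)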


# def binary_crossentropy_with_ranking(y_true, y_pred):
#     """ Trying to combine ranking loss with numeric precision"""
#     # first get the log loss like normal
#     logloss = K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1)
#
#     # next, build a rank loss
#
#     # clip the probabilities to keep stability
#     y_pred_clipped = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
#
#     # translate into the raw scores before the logit
#     y_pred_score = K.log(y_pred_clipped / (1 - y_pred_clipped))
#
#     # determine what the maximum score for a zero outcome is
#     y_pred_score_zerooutcome_max = K.max(y_pred_score * (y_true < 1))
#
#     # determine how much each score is above or below it
#     rankloss = y_pred_score - y_pred_score_zerooutcome_max
#
#     # only keep losses for positive outcomes
#     rankloss = rankloss * y_true
#
#     # only keep losses where the score is below the max
#     rankloss = K.square(K.clip(rankloss, -100, 0))
#
#     # average the loss for just the positive outcomes
#     rankloss = K.sum(rankloss, axis=-1) / (K.sum(y_true > 0) + 1)
#
#     # return (rankloss + 1) * logloss - an alternative to try
#     return rankloss + logloss
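

# A PyTorch translation of the commented Keras reference above, given as a
# hedged sketch: the version actually called in train_model is not shown in
# this excerpt, so the argument order (logits first, one-hot labels second,
# matching the call site) and the global, rather than per-row, reductions are
# assumptions.
import torch
import torch.nn.functional as F


def binary_crossentropy_with_ranking(logits, y_true):
    eps = 1e-7
    y_pred = torch.sigmoid(logits)
    # the usual log loss
    logloss = F.binary_cross_entropy(y_pred, y_true)

    # clip probabilities for stability, then recover raw scores
    y_pred_clipped = y_pred.clamp(eps, 1 - eps)
    score = torch.log(y_pred_clipped / (1 - y_pred_clipped))

    # highest score assigned to any negative example
    score_neg_max = (score * (y_true < 1).float()).max()

    # penalize positives scored below that maximum
    rankloss = (score - score_neg_max) * y_true
    rankloss = rankloss.clamp(-100, 0).pow(2)
    rankloss = rankloss.sum() / ((y_true > 0).float().sum() + 1)

    return rankloss + logloss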