Beispiel #1
0
def main():
    """Evaluate the (optionally checkpoint-initialised) model on both splits.

    Parses command-line arguments, seeds the RNGs, selects the device, builds
    the data loaders and the model, optionally loads weights from
    ``args.resume`` (non-strictly), then runs ``evaluate`` on the training
    loader followed by the validation loader.

    Raises:
        NameError: if unrecognised command-line arguments are supplied.
        RuntimeError: if GPU execution is requested but CUDA is unavailable.
    """
    args, leftover = get_args().parse_known_args()
    if leftover:
        raise NameError("Argument {} not recognized".format(leftover))

    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)

    if not args.cpu:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")
        # Scale the batch size by the number of visible GPUs — presumably so
        # each replica of nn.DataParallel keeps the configured batch size.
        args.devices = torch.cuda.device_count()
        args.batch_size *= args.devices
        torch.backends.cudnn.benchmark = True
        device = torch.device('cuda')
        torch.cuda.manual_seed(seed)
    else:
        device = torch.device('cpu')

    # The answer count is not needed by this model variant.
    train_loader, val_loader, vocab_size, _num_answers = prepare_data(args)

    net = Model(vocab_size, args.word_embed_dim, args.hidden_size,
                args.resnet_out)
    model = nn.DataParallel(net).to(device)

    checkpoint_path = args.resume
    if checkpoint_path:
        print("Initialized from ckpt: " + checkpoint_path)
        state = torch.load(checkpoint_path, map_location=device)
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        model.load_state_dict(state['state_dict'], strict=False)

    for loader, split in ((train_loader, "train"), (val_loader, "val")):
        evaluate(loader, model, device, split)
Beispiel #2
0
    def evaluate(self, dataloader, image_encoder, text_encoder):
        """Compute the mean sentence- and word-level matching losses.

        Both encoders are switched to evaluation mode (and left there), and
        every batch of ``dataloader`` is scored without building an autograd
        graph.

        Args:
            dataloader: iterable of batches accepted by ``prepare_data``.
            image_encoder: module returning ``(words_features, sent_code)``.
            text_encoder: module passed to ``self.text_enc_forward``.

        Returns:
            Tuple ``(s_cur_loss, w_cur_loss)``: the sentence loss and the word
            loss (each the sum of its two terms), averaged over the number of
            batches processed. Both are ``0.0`` for an empty dataloader.
        """
        image_encoder.eval()
        text_encoder.eval()

        s_total_loss = 0.0
        w_total_loss = 0.0
        num_batches = 0

        # Evaluation only: no gradients are needed, so skip graph construction.
        with torch.no_grad():
            for data in dataloader:
                real_imgs, captions, class_ids, input_mask = prepare_data(
                    data, self.device)

                words_features, sent_code = image_encoder(real_imgs[-1])

                batch_size = words_features.size(0)
                words_emb, sent_emb = self.text_enc_forward(
                    text_encoder, captions, input_mask)
                # Labels 0..batch_size-1 — presumably each sample's matching
                # caption index within the batch.
                labels = torch.arange(batch_size, dtype=torch.long,
                                      device=self.device)

                w_loss0, w_loss1, _ = words_loss(words_features, words_emb,
                                                 labels, class_ids, batch_size)
                # Accumulate as tensors to avoid a device sync per batch.
                w_total_loss = w_total_loss + (w_loss0 + w_loss1).detach()

                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                s_total_loss = s_total_loss + (s_loss0 + s_loss1).detach()

                num_batches += 1

        if num_batches == 0:
            return 0.0, 0.0

        # Average over the number of batches actually processed (not the
        # last zero-based step index, which is one smaller).
        s_cur_loss = float(s_total_loss) / num_batches
        w_cur_loss = float(w_total_loss) / num_batches

        return s_cur_loss, w_cur_loss
Beispiel #3
0
def main():
    """Train the model (or, in ``eval`` mode, only evaluate it).

    Parses command-line arguments, creates the logger, seeds the RNGs,
    selects the device, builds the data loaders, model and Adamax optimizer,
    optionally restores model/optimizer state and the epoch counter from
    ``args.resume``, and then either evaluates once (``args.mode == 'eval'``)
    or alternates ``train`` / ``evaluate`` / ``save_ckpt`` per epoch.

    Raises:
        NameError: if unrecognised command-line arguments are supplied.
        RuntimeError: if GPU execution is requested but CUDA is unavailable.
    """
    args, leftover = get_args().parse_known_args()
    if leftover:
        raise NameError("Argument {} not recognized".format(leftover))

    logger = GOATLogger(args.mode, args.save, args.log_freq)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cpu:
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        # Scale the batch size by the number of visible GPUs (the model is
        # wrapped in nn.DataParallel below).
        args.devices = torch.cuda.device_count()
        args.batch_size *= args.devices
        torch.backends.cudnn.benchmark = True
        device = torch.device('cuda')
        torch.cuda.manual_seed(args.seed)
    else:
        raise RuntimeError("GPU unavailable.")

    # Data
    train_loader, val_loader, vocab_size, num_answers = prepare_data(args)

    # Model
    net = Model(vocab_size, args.word_embed_dim, args.hidden_size,
                args.resnet_out, num_answers)
    model = nn.DataParallel(net).to(device)
    param_count = sum(param.numel() for param in model.parameters())
    logger.loginfo("Parameters: {:.3f}M".format(param_count / 1e6))

    # Optimizer
    optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3)

    start_epoch = 0
    if args.resume:
        logger.loginfo("Initialized from ckpt: " + args.resume)
        checkpoint = torch.load(args.resume, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])

    if args.mode == 'eval':
        evaluate(val_loader, model, start_epoch, device, logger,
                 args.data_root)
        return

    # NOTE(review): the best score is not restored on resume, so the first
    # epoch after resuming always compares against 0.0 — confirm intended.
    best_score = 0.0
    for epoch in range(start_epoch, args.epoch):
        train(train_loader, model, optimizer, epoch, device, logger)
        score = evaluate(val_loader, model, epoch, device, logger)
        best_score = save_ckpt(score, best_score, epoch, model, optimizer,
                               args.save, logger)

    logger.loginfo("Done")
Beispiel #4
0
    def validate(self, netG, netsD, text_encoder, image_encoder):
        """Score the generator on ``self.val_loader``.

        The generator and discriminators are put in eval mode for the pass
        and are always switched back to train mode afterwards, even if an
        exception is raised. No autograd graph is built.

        Returns:
            ``(m, s, mean_loss)``: the two values returned by
            ``InceptionScore.get_ic_score()`` and the generator loss
            (``generator_loss`` plus the KL term) averaged over validation
            batches (``0.0`` when the loader is empty).
        """
        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM
        real_labels, _, match_labels = self.prepare_labels()

        # Latent codes, re-sampled in place for every batch.
        noise = torch.empty(batch_size, nz, dtype=torch.float32,
                            device=self.device)

        val_batches = len(self.val_loader)
        netG.eval()
        for netD in netsD:
            netD.eval()

        inception_scorer = InceptionScore(val_batches, batch_size, val_batches)
        total_loss = 0.0
        try:
            with torch.no_grad():
                for step, data in enumerate(self.val_loader):
                    ######################################################
                    # (1) Prepare data and compute text embeddings
                    ######################################################
                    _, captions, class_ids, input_mask = prepare_data(
                        data, self.device)

                    words_embs, sent_emb = self.text_encoder_forward(
                        text_encoder, captions, input_mask)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    # Mask token id 0 (presumably padding), cropped to the
                    # number of word embeddings.
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    ######################################################
                    # (2) Generate fake images and score them
                    ######################################################
                    noise.normal_(0, 1)
                    fake_imgs, _, mu, logvar = netG(noise, sent_emb,
                                                    words_embs, mask)
                    errG_total, _ = generator_loss(netsD, image_encoder,
                                                   fake_imgs, real_labels,
                                                   words_embs, sent_emb,
                                                   match_labels, class_ids,
                                                   self.opts)
                    errG_total = errG_total + KL_loss(mu, logvar)
                    total_loss += errG_total.item()
                    inception_scorer.predict(fake_imgs[-1], step)
        finally:
            # Hand the networks back in training mode regardless of outcome.
            netG.train()
            for netD in netsD:
                netD.train()

        m, s = inception_scorer.get_ic_score()
        mean_loss = total_loss / val_batches if val_batches else 0.0
        return m, s, mean_loss
Beispiel #5
0
    def test(self, model_path, test_loader):
        """Generate images for ``test_loader`` and compute their inception score.

        The text encoder and generator are obtained from
        ``self.build_models_for_test(model_path)``. No autograd graph is built.

        Returns:
            ``(m, s, loss)`` where ``m, s`` come from
            ``InceptionScore.get_ic_score()``. No generator loss is computed
            here (no discriminators or image encoder are built), so ``loss``
            is always ``0.0``; it is kept only to preserve the three-element
            return shape.
        """
        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM

        # Latent codes, re-sampled in place for every batch.
        noise = torch.empty(batch_size, nz, dtype=torch.float32,
                            device=self.device)

        text_encoder, netG = self.build_models_for_test(model_path)
        num_batches = len(test_loader)

        inception_scorer = InceptionScore(num_batches, batch_size, num_batches)
        with torch.no_grad():
            for step, data in enumerate(test_loader):
                ######################################################
                # (1) Prepare data and compute text embeddings
                ######################################################
                _, captions, _, input_mask = prepare_data(data, self.device)

                words_embs, sent_emb = self.text_encoder_forward(
                    text_encoder, captions, input_mask)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask token id 0 (presumably padding), cropped to the number
                # of word embeddings.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                ######################################################
                # (2) Generate fake images and score the last stage
                ######################################################
                noise.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
                inception_scorer.predict(fake_imgs[-1], step)

        m, s = inception_scorer.get_ic_score()
        return m, s, 0.0
Beispiel #6
0
def _checkpoint_path(model_path, epoch, train_loss, valid_loss):
    """Return the per-epoch checkpoint file name derived from ``model_path``.

    ``"lm.pth"`` becomes
    ``"lm.<epoch>.<train loss>-<train ppl>.<valid loss>-<valid ppl>.pth"``;
    a path without an extension simply gets the fields appended.
    """
    import os

    root, ext = os.path.splitext(model_path)
    fields = [
        root,
        "%02d" % epoch,
        "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
        "%.2f-%.2f" % (valid_loss, np.exp(valid_loss)),
    ]
    return ".".join(fields) + ext


def _run_training_pass(model, criterion, train_iter, optimizer, epoch, config):
    """Train for one pass over ``train_iter`` and return the last logged loss.

    Statistics are printed every ``config.print_every`` batches. The return
    value is the per-word loss from the most recent report, or ``np.inf`` if
    no report was produced.
    """
    sample_cnt = 0
    total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
    start_time = time.time()
    train_loss = np.inf

    for batch_index, batch in enumerate(train_iter):
        optimizer.zero_grad()

        batch = prepare_data(batch)

        current_batch_word_cnt = torch.sum(batch[1])
        # The model reads BOS + sentence and predicts sentence + EOS, so the
        # input x drops the last token and the target y drops the first.
        x = batch[0][:, :-1]
        y = batch[0][:, 1:]
        # Size the initial hidden state from the actual batch (dim 0): the
        # final batch can be smaller than config.batch_size.
        hidden = model.init_hidden(x.size(0))
        y_hat = model(x, batch[1], hidden)

        # Computes the loss and (presumably, since do_backward is left at its
        # default) back-propagates it.
        loss = get_loss(y, y_hat, criterion)

        # Accumulate statistics for the periodic report.
        total_loss += float(loss)
        total_word_count += int(current_batch_word_cnt)
        total_parameter_norm += float(
            utils.get_parameter_norm(model.parameters()))
        total_grad_norm += float(utils.get_grad_norm(model.parameters()))

        if (batch_index + 1) % config.print_every == 0:
            avg_loss = total_loss / total_word_count
            avg_parameter_norm = total_parameter_norm / config.print_every
            avg_grad_norm = total_grad_norm / config.print_every
            elapsed_time = time.time() - start_time

            print(
                "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                % (epoch, batch_index + 1,
                   int((len(train_iter.dataset) // config.batch_size)),
                   avg_parameter_norm, avg_grad_norm, avg_loss,
                   np.exp(avg_loss), total_word_count // elapsed_time,
                   elapsed_time))

            total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
            start_time = time.time()

            train_loss = avg_loss

        # Clip gradients to avoid explosion, then take an SGD step.
        torch_utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()

        # The iterator may repeat indefinitely; stop after one dataset pass.
        sample_cnt += batch[0].size(0)
        if sample_cnt >= len(train_iter.dataset):
            break

    return train_loss


def _run_validation_pass(model, criterion, valid_iter):
    """Return the per-word loss over one pass of ``valid_iter``.

    The model is put in eval mode for the pass (no autograd graph is built)
    and switched back to train mode afterwards.
    """
    sample_cnt = 0
    total_loss, total_word_count = 0, 0

    model.eval()
    with torch.no_grad():
        for batch in valid_iter:
            batch = prepare_data(batch)
            current_batch_word_cnt = torch.sum(batch[1])
            x = batch[0][:, :-1]
            y = batch[0][:, 1:]
            hidden = model.init_hidden(x.size(0))

            y_hat = model(x, batch[1], hidden)

            loss = get_loss(y, y_hat, criterion, do_backward=False)

            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)

            sample_cnt += batch[0].size(0)
            if sample_cnt >= len(valid_iter.dataset):
                break
    model.train()

    return total_loss / total_word_count


def train_epoch(model, criterion, train_iter, valid_iter, config):
    """Train a language model with SGD for ``config.n_epochs`` epochs.

    Each epoch builds a fresh SGD optimizer at the current learning rate,
    trains for one pass over ``train_iter``, and evaluates on
    ``valid_iter``. When validation loss fails to improve, the learning rate
    is divided by 10. A checkpoint is saved every epoch under a name derived
    from ``config.model``. Training stops early once the number of
    consecutive non-improving epochs exceeds ``config.early_stop`` (if > 0).
    """
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    # Epochs are numbered 1..n_epochs inclusive.
    for epoch in range(1, config.n_epochs + 1):
        optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        train_loss = _run_training_pass(model, criterion, train_iter,
                                        optimizer, epoch, config)
        valid_loss = _run_validation_pass(model, criterion, valid_iter)
        print("valid loss: %.4f\tPPL: %.2f" % (valid_loss, np.exp(valid_loss)))

        if lowest_valid_loss > valid_loss:
            lowest_valid_loss = valid_loss
            no_improve_cnt = 0
        else:
            # Decrease the learning rate when there is no improvement.
            current_lr /= 10.
            no_improve_cnt += 1

        # torch.save pickles the object; only load these files from trusted
        # sources.
        torch.save(
            {
                "model": model.state_dict(),
                "config": config,
                "epoch": epoch + 1,
                "current_lr": current_lr
            }, _checkpoint_path(config.model, epoch, train_loss, valid_loss))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
Beispiel #7
0
    def __init__(self, dataset, opts, use_pretrained_embeddings=True):
        """Build the bidirectional-LSTM sense classifier graph and its session.

        Builds the vocabulary from ``dataset.training_files`` (optionally with
        frozen word2vec embeddings) and wires the input pipeline via
        ``prepare_data``. Each sentence is encoded with a bidirectional LSTM,
        and the concatenation of the final cell states with the target word's
        embedding is classified into one of ``len(label2idx)`` senses. All
        variables are initialised in ``self.sess`` before returning.

        Args:
            dataset: object exposing ``training_files`` and ``vocab_file``.
            opts: dict with at least ``min_counts``, ``hidden_unit_size`` and
                ``learning_rate`` (plus ``embedding_length`` when
                ``use_pretrained_embeddings`` is False). It is also forwarded
                as keyword arguments to ``prepare_data``.
            use_pretrained_embeddings: if True, use non-trainable word2vec
                vectors loaded from ``WORD2VEC_PATH``; otherwise learn
                embeddings initialised uniformly in [-1, 1].
        """

        # TODO: Add Dropout layer later.
        # NOTE(review): this placeholder is not yet used by the graph below.
        self.dropout_keep_prob = tf.placeholder(tf.float32,
                                                name="dropout_keep_prob")

        if use_pretrained_embeddings:
            word2vec = get_word2vec_model(WORD2VEC_PATH)
            word2idx, idx2word, label2idx, idx2label = build_vocab(
                dataset.training_files,
                dataset.vocab_file,
                word2vec,
                min_counts=opts['min_counts'])
            embedding_weights = get_embedding_weights(word2idx, word2vec)
            embedding_length = embedding_weights.shape[1]
            # TODO: embedding might be trainable.
            self.embeddings = tf.Variable(embedding_weights,
                                          dtype=tf.float32,
                                          trainable=False)
        else:
            word2idx, idx2word, label2idx, idx2label = build_vocab(
                dataset.training_files,
                dataset.vocab_file,
                min_counts=opts['min_counts'])
            embedding_length = opts['embedding_length']
            self.embeddings = tf.Variable(tf.random_uniform(
                [len(word2idx), embedding_length], -1.0, 1.0),
                                          dtype=tf.float32)

        self.sess = tf.Session()

        self.enqueue_data, self.source, self.target_word, self.label, \
            self.sequence_length = prepare_data(self.sess, dataset.training_files, word2idx, label2idx, **opts)

        self.target_words_embedded = tf.nn.embedding_lookup(
            self.embeddings, self.target_word)
        self.sentences_embedded = tf.nn.embedding_lookup(
            self.embeddings, self.source)

        hidden_unit_size = opts['hidden_unit_size']
        num_senses = len(label2idx)

        # Each direction needs its own cell. Passing one LSTMCell instance as
        # both cell_fw and cell_bw either ties the two directions to a single
        # set of weights or, on some TF 1.x releases, raises a variable-scope
        # reuse error.
        encoder_cell_fw = LSTMCell(hidden_unit_size)
        encoder_cell_bw = LSTMCell(hidden_unit_size)

        # time_major=True: inputs are expected as [max_time, batch, depth]
        # (assumes self.source is time-major — TODO confirm with prepare_data).
        (encoder_fw_outputs, encoder_bw_outputs), (encoder_fw_final_state, encoder_bw_final_state) = \
            tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell_fw, cell_bw=encoder_cell_bw, inputs=self.sentences_embedded,
                                            sequence_length=self.sequence_length, dtype=tf.float32, time_major=True)

        encoder_final_state_c = tf.concat(
            (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
        encoder_final_state_h = tf.concat(
            (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
        encoder_final_state = LSTMStateTuple(c=encoder_final_state_c,
                                             h=encoder_final_state_h)

        # Classifier input: [fw cell state | bw cell state | target embedding].
        self.encoder_target_embedding = tf.concat(
            (encoder_final_state.c, self.target_words_embedded), 1)

        with tf.name_scope("output"):
            W = tf.Variable(tf.truncated_normal(
                [hidden_unit_size * 2 + embedding_length, num_senses],
                stddev=0.1),
                            name="W")
            b = tf.Variable(tf.constant(0.1, shape=[num_senses]), name="b")
            self.scores = tf.matmul(self.encoder_target_embedding, W) + b
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        with tf.name_scope('cross_entropy'):
            labels = tf.one_hot(self.label, num_senses)
            self.diff = tf.nn.softmax_cross_entropy_with_logits(
                labels=labels, logits=self.scores)

        with tf.name_scope('loss'):
            self.loss = tf.reduce_mean(self.diff)

        with tf.name_scope('train'):
            self.train_step = tf.train.AdamOptimizer(
                opts['learning_rate']).minimize(self.loss)

        with tf.name_scope('accuracy'):
            with tf.name_scope('correct_prediction'):
                correct_prediction = tf.equal(self.predictions,
                                              tf.argmax(labels, 1))
            with tf.name_scope('accuracy'):
                self.accuracy = tf.reduce_mean(
                    tf.cast(correct_prediction, tf.float32))

        self.sess.run(tf.global_variables_initializer())
Beispiel #8
0
def get_batch_manager(id_to_tag, tag_to_id, text, word_to_id):
    """Wrap raw ``text`` in a ``BatchManager`` ready for inference.

    The text is converted by ``get_test_data2``, encoded with the word/tag
    vocabularies via ``prepare_data`` (truncated/padded per
    ``FLAGS.word_max_len``), and batched with ``FLAGS.valid_batch_size``.
    """
    raw_examples = get_test_data2(text)
    max_len = FLAGS.word_max_len
    encoded = prepare_data(raw_examples, word_to_id, tag_to_id, max_len)
    # Load the data into a batch iterator.
    return BatchManager(encoded, len(id_to_tag), max_len,
                        FLAGS.valid_batch_size)
Beispiel #9
0
    def train(self, dataloader, image_encoder, text_encoder, optimizer, epoch,
              ixtoword, image_dir, batch_size):
        """Train the image and text encoders for one epoch on matching losses.

        For every batch, the word-level (``words_loss``) and sentence-level
        (``sent_loss``) losses are summed and back-propagated. The text
        encoder's gradients are clipped unless it is BERT, and the optimizer
        is stepped. Running averages are printed every
        ``self.update_interval`` steps, and attention maps for the final
        batch are written to ``image_dir``.

        Args:
            dataloader: iterable of batches accepted by ``prepare_data``;
                must support ``len()``.
            image_encoder: module returning ``(words_features, sent_code)``.
            text_encoder: module passed to ``self.text_enc_forward``.
            optimizer: optimizer over the trainable encoder parameters.
            epoch: current epoch index (used for logging and ``count``).
            ixtoword: index-to-word mapping used to caption attention maps.
            image_dir: directory the attention-map PNG is written to.
            batch_size: kept for interface compatibility; the effective batch
                size is re-read from each batch's image features.

        Returns:
            ``(count, s_epoch_loss, w_epoch_loss)``: a global step counter and
            the per-batch mean sentence and word losses for the epoch (each
            the sum of its two terms; ``0.0`` for an empty dataloader).
        """
        image_encoder.train()
        text_encoder.train()

        # Running sums, reset after every progress report.
        s_total_loss0 = 0
        s_total_loss1 = 0
        w_total_loss0 = 0
        w_total_loss1 = 0

        # Sums over the whole epoch.
        s_epoch_loss = 0
        w_epoch_loss = 0

        num_batches = len(dataloader)
        count = (epoch + 1) * num_batches
        start_time = time.time()

        for step, data in enumerate(dataloader, 0):
            optimizer.zero_grad()

            imgs, captions, class_ids, input_mask = prepare_data(
                data, self.device)

            words_features, sent_code = image_encoder(imgs[-1])

            # The effective batch size and attention-map size come from the
            # image features (may differ from the batch_size argument).
            batch_size, _, att_size, _ = words_features.shape

            words_emb, sent_emb = self.text_enc_forward(
                text_encoder, captions, input_mask)
            labels = torch.arange(batch_size, dtype=torch.long,
                                  device=self.device)

            w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                     labels, class_ids,
                                                     batch_size)
            w_val0, w_val1 = w_loss0.item(), w_loss1.item()
            w_total_loss0 += w_val0
            w_total_loss1 += w_val1
            w_epoch_loss += w_val0 + w_val1
            loss = w_loss0 + w_loss1

            s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                         class_ids, batch_size)
            loss += s_loss0 + s_loss1
            s_val0, s_val1 = s_loss0.item(), s_loss1.item()
            s_total_loss0 += s_val0
            s_total_loss1 += s_val1
            s_epoch_loss += s_val0 + s_val1

            # NOTE(review): retain_graph=True keeps this step's graph buffers
            # alive until `loss` is rebound. It is kept because
            # text_enc_forward is not visible here; drop it if no tensor is
            # reused across steps.
            loss.backward(retain_graph=True)

            # `clip_grad_norm` helps prevent
            # the exploding gradient problem in RNNs / LSTMs.
            if self.opts.TEXT.ENCODER != 'bert':
                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                               self.opts.TRAIN.RNN_GRAD_CLIP)
            optimizer.step()

            if step != 0 and step % self.update_interval == 0:
                count = epoch * num_batches + step

                s_cur_loss0 = s_total_loss0 / self.update_interval
                s_cur_loss1 = s_total_loss1 / self.update_interval

                w_cur_loss0 = w_total_loss0 / self.update_interval
                w_cur_loss1 = w_total_loss1 / self.update_interval

                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                    's_loss {:5.2f} {:5.2f} | '
                    'w_loss {:5.2f} {:5.2f}'.format(
                        epoch, step, num_batches,
                        elapsed * 1000. / self.update_interval, s_cur_loss0,
                        s_cur_loss1, w_cur_loss0, w_cur_loss1))
                s_total_loss0 = 0
                s_total_loss1 = 0
                w_total_loss0 = 0
                w_total_loss1 = 0
                start_time = time.time()

            if step == num_batches - 1:
                # Attention maps for the final batch of the epoch.
                img_set, _ = build_super_images(imgs[-1].cpu(),
                                                captions,
                                                ixtoword,
                                                attn_maps,
                                                att_size,
                                                None,
                                                batch_size,
                                                max_word_num=18)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                    im.save(fullpath)

        if num_batches:
            s_epoch_loss /= num_batches
            w_epoch_loss /= num_batches
        return count, s_epoch_loss, w_epoch_loss
Beispiel #10
0
    def sampling(self, split_dir, model_path):
        """Generate and save sample images for the first 51 training batches.

        The text encoder and generator come from
        ``self.build_models_for_test(model_path)``. For each batch of
        ``self.train_loader`` (steps 0..50), each caption is printed and the
        last generator stage's image is saved as
        ``<output_dir>/samples/single/<step>_<j>_s-1.png``. No autograd graph
        is built.

        Args:
            split_dir: currently unused; kept for interface compatibility.
            model_path: checkpoint path passed to ``build_models_for_test``.
        """
        text_encoder, netG = self.build_models_for_test(model_path)

        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM
        # Allocate on the configured device (not unconditionally on CUDA).
        # `Variable(volatile=True)` has been a no-op since PyTorch 0.4;
        # gradient tracking is disabled with torch.no_grad() below instead.
        noise = torch.empty(batch_size, nz, dtype=torch.float32,
                            device=self.device)

        # The path to save generated images.
        save_dir = os.path.join(self.output_dir, "samples")
        make_dir(save_dir)
        single_dir = '%s/single' % save_dir
        if not os.path.isdir(single_dir):
            print('Make a new folder: ', single_dir)
            make_dir(single_dir)

        k = -1  # index of the generator stage that is saved (the last one)
        with torch.no_grad():
            for step, data in enumerate(self.train_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)
                if step > 50:
                    break

                _, captions, _, input_mask = prepare_data(data, self.device)

                words_embs, sent_emb = self.text_encoder_forward(
                    text_encoder, captions, input_mask)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask token id 0, cropped to the number of word embeddings.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

                # Move the whole batch to host memory once.
                caps = captions.cpu().numpy()
                images = fake_imgs[k].cpu().numpy()

                for j in range(batch_size):
                    name = "%d_%d" % (step, j)
                    s_tmp = '%s/single/%s' % (save_dir, name)

                    # Decode tokens up to the first 0; each word is followed
                    # by a single space.
                    words = []
                    for token in caps[j]:
                        if token == 0:
                            break
                        words.append(self.ixtoword[token].encode(
                            'ascii', 'ignore').decode('ascii'))
                    print(name, ''.join(word + ' ' for word in words))

                    # [-1, 1] --> [0, 255], CHW --> HWC
                    im = (images[j] + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    Image.fromarray(im).save('%s_s%d.png' % (s_tmp, k))
Beispiel #11
0
    def train(self):
        """Run GAN training from the restored start epoch to ``self.max_epoch``.

        Per batch, the method:
        - encodes captions with the text encoder (embeddings detached);
        - generates images and updates every discriminator with its own
          optimizer;
        - updates the generator on ``generator_loss`` plus the KL loss;
        - refreshes an exponential moving average of the generator
          parameters.

        After each epoch the generator is validated, validation and loss
        logs are written, and a snapshot is saved every
        ``opts.TRAIN.SNAPSHOT_INTERVAL`` epochs. A final snapshot is saved at
        the end.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM
        noise = torch.FloatTensor(batch_size, nz).to(self.device)
        # Fixed latent codes used when saving image snapshots.
        fixed_noise = torch.FloatTensor(batch_size, nz).normal_(0, 1).to(
            self.device)

        gen_iterations = 0

        lr_schedulers = []
        if self.use_lr_scheduler:
            for optimizerD in optimizersD:
                lr_scheduler = LambdaLR(optimizerD,
                                        lr_lambda=lambda epoch: 0.998**epoch)
                # Fast-forward the schedule when resuming from start_epoch.
                for _ in range(start_epoch):
                    lr_scheduler.step()
                lr_schedulers.append(lr_scheduler)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.train_loader)
            step = 0

            for lr_scheduler in lr_schedulers:
                lr_scheduler.step()

            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                imgs, captions, class_ids, captions_mask = prepare_data(
                    data, self.device)

                words_embs, sent_emb = self.text_encoder_forward(
                    text_encoder, captions, captions_mask)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask token id 0, cropped to the number of word embeddings.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i, netD in enumerate(netsD):
                    netD.zero_grad()
                    errD = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    # Detached: the total is only used for logging.
                    errD_total += errD.detach()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, class_ids, self.opts)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights
                # (the keyword `alpha` form replaces the deprecated
                # add_(Number, Tensor) overload).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 10 == 0:
                    print("Epoch: " + str(epoch) + " Step: " + str(step) +
                          " " + D_logs + '\n' + G_logs)
                # Save images rendered with the averaged generator weights.
                if gen_iterations % 300 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          epoch,
                                          step,
                                          name='average')
                    load_params(netG, backup_para)

            # NOTE(review): errD_total / errG_total below are undefined if
            # self.num_batches == 0 — confirm that can never happen.
            is_mean, is_std, error_G_val = self.validate(
                netG, netsD, text_encoder, image_encoder)
            self.val_logger.write("{} {} {}\n".format(epoch, is_mean, is_std))
            self.val_logger.flush()

            self.losses_logger.write("{} {} {}\n".format(
                epoch, errG_total.item(), error_G_val))
            self.losses_logger.flush()

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches,
                   errD_total.item(), errG_total.item(),
                   end_t - start_t))

            print("IS: {} {}".format(is_mean, is_std))
            if epoch % self.opts.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)