def robustness_exp(dataset, max_tree_depth, forest_depth, num_trees, num_experiments=1,
                      score='accuracy', weight=None, n_splits=10, use_agreement=False, delta=1):
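    """K-fold robustness experiment: on each fold, compress a random forest
    into small trees, average per-fold scores over `num_experiments` runs,
    and optionally report the agreement between trees produced across folds."""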

    kf_scores = []
    kf = KFold(n_splits=n_splits)
    x, y, X_test, y_test = datasets.prepare_data(dataset, return_test=True)
    c = datasets.get_number_of_classes(dataset)
    score_func, score_metric = metrics.get_socre_foncs(score)

    trees = []
    for train, test in kf.split(x):
        X_train, _, y_train, _ = x[train], x[test], y[train], y[test]
        X_train, X_val, y_train, y_val = datasets.prepare_val(X_train, y_train)
        k_scores = []
        for k in range(num_experiments):
            rf, _, f_med, f_all, f_m = compress_tree(num_trees, max_tree_depth, forest_depth, X_train, y_train,
                                                     c, weight=weight, X_val=X_val, y_val=y_val,
                                                     score=score_metric, delta=delta)

            k_scores.append(score_func(rf, None, f_med, f_all, f_m, X_train, y_train, X_test, y_test))

        trees.append((rf, f_m, f_all, f_med))
        kf_scores.append(metrics.average_scores(k_scores, num_experiments))

    means = metrics.mean_and_std(kf_scores, mean_only=True)
    output = metrics.agreement_score(trees, X_test) if use_agreement else None
    return output, means
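
A minimal usage sketch (hedged: the dataset name and hyperparameters are illustrative, and the `datasets`/`metrics`/`compress_tree` helpers above must be importable):

agreement, mean_scores = robustness_exp('iris', max_tree_depth=4, forest_depth=10,
                                        num_trees=100, num_experiments=5,
                                        use_agreement=True)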
Example #2
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, class_ids, keys, \
                wrong_caps, wrong_caps_len, wrong_cls_id = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        # NOTE: `labels` must come from an enclosing scope: a LongTensor of
        # match labels 0..batch_size-1 (see Example #13 for its construction).
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).item()

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    # enumerate starts at 0 and the loop breaks at step == 50, so step + 1 batches ran
    s_cur_loss = s_total_loss / (step + 1)
    w_cur_loss = w_total_loss / (step + 1)

    return s_cur_loss, w_cur_loss
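
Since this evaluate() never calls backward(), callers typically wrap it in torch.no_grad() to save memory; a minimal sketch (the loader and encoder names are assumptions):

import torch

with torch.no_grad():  # disable gradient tracking during validation
    s_loss, w_loss = evaluate(val_loader, image_encoder, text_encoder, batch_size)
Example #3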
def generalization_exp(dataset, max_tree_depth, forest_depth, num_trees, num_experiments=1,
                      score='accuracy', weight=None, n_splits=10, delta=1):

    kf_scores = []
    kf = KFold(n_splits=n_splits)
    x, y, _, _ = datasets.prepare_data(dataset, return_test=False)
    c = datasets.get_number_of_classes(dataset)
    score_func, score_metric = metrics.get_socre_foncs(score)

    for k in range(num_experiments):
        f_scores = []
        for train, test in kf.split(x):
            X_train, X_test, y_train, y_test = x[train], x[test], y[train], y[test]
            X_train, X_val, y_train, y_val = datasets.prepare_val(X_train, y_train)
            rf, _, f_med, f_all, f_m = compress_tree(num_trees, max_tree_depth, forest_depth, X_train, y_train,
                                                     c, weight=weight, X_val=X_val, y_val=y_val,
                                                     score=score_metric, delta=delta)

            f_scores.append(score_func(rf, None, f_med, f_all, f_m, X_train, y_train, X_test, y_test))

        mean_var_win = metrics.mean_and_std(f_scores, mean_only=False)
        kf_scores.append(mean_var_win)

    print('\nFinal results:')
    print(f'Average RF mean {sum([score[0] for score in kf_scores]) / num_experiments}, var {sum([score[1] for score in kf_scores]) / num_experiments}')
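    # each kf_scores row is assumed to hold (rf_mean, rf_var) followed by a
    # (mean, std, wins) triple per compressed model (BM, VT, MED), hence idx steps of 3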
    idx = 2
    for t in ('BM', 'VT', 'MED'):
        t_mean = sum([score[idx] for score in kf_scores]) / num_experiments
        t_wins = sum([score[idx + 2] for score in kf_scores]) / num_experiments
        idx += 3
        print(f'Average {t} mean {t_mean}, wins {t_wins}')

    return
Example #4
def robustness_exp(max_tree_depth, num_experiments, dataset, score='accuracy',
                   weight=None, n_splits=10, device='cpu', use_agreement=False, delta=1):

    kf_scores = []
    kf = KFold(n_splits=n_splits)
    x, y, X_test, y_test = datasets.prepare_data(dataset, return_test=True)

    score_func, score_metric = metrics.get_socre_foncs(score)
    c = datasets.get_number_of_classes(dataset)

    trees = []
    nn_score = []  # collect across all folds so the final mean covers every fold
    for train, test in kf.split(x):
        X_train, _, y_train, _ = x[train], x[test], y[train], y[test]
        X_train, X_val, y_train, y_val = datasets.prepare_val(X_train, y_train)
        X_nn = np.concatenate([X_train, X_val], axis=0)
        y_nn = np.concatenate([y_train, y_val], axis=0)
        k_scores = []
        for k in range(num_experiments):
            p_ic, nn_test_score = train_oracle_and_predict(dataset, X_nn, y_nn, X_test, y_test, c, device)
            f_med, f_voting, f_m = compress_tree(max_tree_depth, X_train, y_train, p_ic, weight=weight, X_val=X_val,
                                                 y_val=y_val, score=score_metric, delta=delta)
            k_scores.append(score_func(None, None, f_med, f_voting, f_m, X_train, y_train, X_test, y_test))
            nn_score.append(nn_test_score)

        trees.append((f_m, f_voting, f_med))
        kf_scores.append(metrics.average_scores(k_scores, num_experiments))

    means = metrics.mean_and_std(kf_scores, mean_only=True, show_rf=False)
    output = metrics.agreement_score(trees, X_test) if use_agreement else None

    nn_av = np.mean(nn_score)
    return output, means, nn_av
Example #5
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / (step + 1)  # step + 1 batches were processed
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
Example #6
def crembo_sklearn_example():
    # set arguments
    dataset = 'dermatology'
    args = {
        'dataset': dataset,  # keep args consistent with the dataset actually loaded below
        'num_trees': 100,
        'tree_depth': 4,
        'forest_depth': 10,
        'weight': 'balanced',
        'sklearn': True
    }

    # Create train, test, val sets
    x, y, X_test, y_test = datasets.prepare_data(dataset, return_test=True)
    X_train, X_val, y_train, y_val = datasets.prepare_val(x, y)
    train_loader = (X_train, y_train)
    test_loader = (X_test, y_test)
    val_loader = (X_val, y_val)

    # train large model
    M = RandomForestClassifier(n_estimators=args['num_trees'],
                               max_depth=args['forest_depth'],
                               class_weight=args['weight'])
    M.fit(X_train, y_train)

    # define create_model method
    # Here we create a tree model. All sklearn models should be wrapped by a class that
    # inherits from the MCSkLearnConsistensy class.
    create_model_func = MCConsistentTree(depth=args['tree_depth'],
                                         class_weight=args['weight']).get_clone

    # define train_hypothesis method
    train_hypothesis_func = train_sklearn

    # define eval_model method
    eval_model_func = eval_sklearn

    # instantiate the CREMBO class
    crembo = CREMBO(create_model_func,
                    train_hypothesis_func,
                    eval_model_func,
                    args,
                    delta=1)

    # run crembo
    f = crembo(M, train_loader, test_loader, val_loader, device=None)

    # print scores
    f_score = eval_model_func(f, test_loader, None)
    M_score = eval_model_func(M, test_loader, None)
    print(f'M score: {M_score}, CREMBO: {f_score}')

    return f
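
The train_sklearn / eval_sklearn hooks referenced above are not shown; a hedged sketch of what they might look like, inferred only from how this example calls them (the real CREMBO signatures may differ):

def train_sklearn(model, train_loader, device=None):
    X, y = train_loader          # loaders here are plain (X, y) tuples
    model.fit(X, y)
    return model

def eval_sklearn(model, loader, device=None):
    X, y = loader
    return model.score(X, y)     # mean accuracy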
Example #7
    def scatter(self, inputs, kwargs, device_ids):
        # More like scatter and data prep at the same time. The point is we prep the data in such a way
        # that no scatter is necessary, and there's no need to shuffle stuff around different GPUs.
        # fall back to a single CPU "device" so len(devices) below stays valid without CUDA
        devices = ['cuda:' + str(x) for x in device_ids] if args.cuda else ['cpu']
        splits = prepare_data(inputs[0],
                              devices,
                              allocation=args.batch_alloc,
                              batch_size=args.batch_size,
                              is_cuda=args.cuda,
                              train_mode=True)

        return [[split[device_idx] for split in splits] for device_idx in range(len(devices))], \
               [kwargs] * len(devices)
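
A hedged sketch of how an overridden scatter() like this plugs into torch.nn.DataParallel (the wrapper class name is hypothetical; only the scatter() override above is from the source):

import torch.nn as nn

class PresplitDataParallel(nn.DataParallel):
    """DataParallel variant whose scatter() hands each replica a pre-built split."""
    def scatter(self, inputs, kwargs, device_ids):
        splits = inputs[0]  # assume the caller already split the batch per GPU
        return [[split] for split in splits], [kwargs] * len(splits)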
Example #8
def evaluate(dataloader, cnn_model, rnn_model, batch_size, writer, count,
             ixtoword, labels, image_dir):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        nef, att_sze = words_features.size(1), words_features.size(2)

        # hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).item()  # accumulate as float, matching the sentence loss below

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    s_cur_loss = s_total_loss / (step + 1)  # step + 1 batches were processed
    w_cur_loss = w_total_loss / (step + 1)

    writer.add_scalars(main_tag="eval_loss",
                       tag_scalar_dict={
                           's_loss': s_cur_loss,
                           'w_loss': w_cur_loss
                       },
                       global_step=count)
    # save an image of the attention maps from the last processed batch
    img_set, _ = \
        build_super_images(real_imgs[-1][:,:3].cpu(), captions,
                           ixtoword, attn, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/attention_maps_eval_%d.png' % (image_dir, count)
        im.save(fullpath)
        writer.add_image(tag="image_DAMSM_eval",
                         img_tensor=transforms.ToTensor()(im),
                         global_step=count)
    return s_cur_loss, w_cur_loss
Example #9
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0

    count = 0
    img_lst = []
    sent_lst = []
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        count += batch_size

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        img_lst.append(sent_code)
        sent_lst.append(sent_emb)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if count >= 1000:
            break

    s_cur_loss = s_total_loss / (step + 1)  # step + 1 batches were processed
    w_cur_loss = w_total_loss / (step + 1)

    img_embs = torch.cat(img_lst)
    sent_embs = torch.cat(sent_lst)

    acc, pred = compute_topk(img_embs, sent_embs)

    # `epoch` and `lr` here are assumed to come from the enclosing training scope
    logger.info(
        '| end epoch {:3d} | top-5 ({:4d}) {:5.2f} valid loss {:5.2f} {:5.2f} | lr {:.5f}|'
        .format(epoch, count, acc, s_cur_loss, w_cur_loss, lr))

    return s_cur_loss, w_cur_loss
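
compute_topk() is not defined in this excerpt; a hedged sketch of a top-5 image-to-text retrieval accuracy consistent with how it is called above (the real implementation may differ):

import torch

def compute_topk(img_embs, sent_embs, k=5):
    img = img_embs / img_embs.norm(dim=1, keepdim=True)   # cosine-normalise
    txt = sent_embs / sent_embs.norm(dim=1, keepdim=True)
    sims = img @ txt.t()                                  # N x N similarities
    pred = sims.topk(k, dim=1).indices                    # k best captions per image
    target = torch.arange(img.size(0), device=img.device).unsqueeze(1)
    acc = (pred == target).any(dim=1).float().mean().item() * 100
    return acc, pred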
Example #10
def test(dataloader, cnn_model, rnn_model, batch_size, labels, ixtoword,
         image_dir, input_channels):
    cnn_model.eval()
    rnn_model.eval()
    # s_total_loss0 = 0
    # s_total_loss1 = 0
    # w_total_loss0 = 0
    # w_total_loss1 = 0
    # count = (epoch + 1) * len(dataloader)
    # start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        # rnn_model.zero_grad()
        # cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])

        # --> batch_size x nef x 17*17
        nef, att_len, att_sze = (words_features.size(1), words_features.size(2),
                                 words_features.size(3))
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        hidden = rnn_model.init_hidden()
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)

        img_set, _ = \
            build_super_images(imgs[-1].cpu(), captions,
                               ixtoword, attn_maps, att_len, att_sze, input_channels)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/attention_maps%d.png' % (image_dir, step)
            # fullpath = '%s/attention_maps%d.png' % (image_dir, count)
            im.save(fullpath)

    # return count
    return step
Example #11
def sampling(text_encoder, netG, dataloader, device):

    model_dir = cfg.TRAIN.NET_G
    split_dir = 'valid'
    # Build and load the generator
    netG.load_state_dict(torch.load('models/%s/netG.pth' % (cfg.CONFIG_NAME)))
    netG.eval()

    batch_size = cfg.TRAIN.BATCH_SIZE
    s_tmp = model_dir
    save_dir = '%s/%s' % (s_tmp, split_dir)
    mkdir_p(save_dir)
    cnt = 0
    for i in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
        for step, data in enumerate(dataloader, 0):
            imags, captions, cap_lens, class_ids, keys = prepare_data(data)
            cnt += batch_size
            if step % 100 == 0:
                print('step: ', step)
            # if step > 50:
            #     break
            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            #######################################################
            # (2) Generate fake images
            ######################################################
            with torch.no_grad():
                noise = torch.randn(batch_size, 100)
                noise = noise.to(device)
                fake_imgs = netG(noise, sent_emb)
            for j in range(batch_size):
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                im = fake_imgs[j].data.cpu().numpy()
                # [-1, 1] --> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_%3d.png' % (s_tmp, i)  # note: '%3d' pads with spaces; '%03d' is likely intended
                im.save(fullpath)
Example #12
def evaluate(dataloader, cnn_model, nlp_model, text_encoder_type, batch_size):
    cnn_model.eval()
    nlp_model.eval()
    text_encoder_type = text_encoder_type.casefold()
    assert text_encoder_type in (
        'rnn',
        'transformer',
    )
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        if text_encoder_type == 'rnn':
            hidden = nlp_model.init_hidden(batch_size)
            words_emb, sent_emb = nlp_model(captions, cap_lens, hidden)
        elif text_encoder_type == 'transformer':
            words_emb = nlp_model(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_emb[:, :, -1].contiguous()
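            # last-token pooling: the final position's output stands in for the sentence embedding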
            # sent_emb = sent_emb.view(batch_size, -1)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / (step + 1)  # step + 1 batches were processed
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
Example #13
def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, class_ids = prepare_data(data)
        words_features, sent_code = cnn_model(real_imgs[-1])
        if step == len(dataloader) - 1:
            # the last batch may be smaller; rebuild the match labels to its true size
            # (`subset_val` is assumed to be the validation subset from the enclosing scope)
            batch_size = len(subset_val) - (len(dataloader) - 1) * batch_size
            labels = Variable(torch.LongTensor(range(batch_size)))
            labels = labels.cuda()
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data
        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data
    s_cur_loss = s_total_loss.item() / len(dataloader)
    w_cur_loss = w_total_loss.item() / len(dataloader)
    return s_cur_loss, w_cur_loss
Example #14
def generalization_exp(max_tree_depth, num_experiments=1, dataset='mnist',
              score='accuracy', weight=None, device='cpu', delta=1, n_splits=10):

    kf_scores = []
    kf = KFold(n_splits=n_splits)
    x, y, _, _ = datasets.prepare_data(dataset, return_test=False)

    c = datasets.get_number_of_classes(dataset)
    score_func, score_metric = metrics.get_socre_foncs(score)

    nn_score = []
    for k in range(num_experiments):
        print(f'Experiment number {k+1}')
        f_scores = []
        for train, test in kf.split(x):
            X_train, X_test, y_train, y_test = x[train], x[test], y[train], y[test]
            X_train, X_val, y_train, y_val = datasets.prepare_val(X_train, y_train)
            X_nn = np.concatenate([X_train, X_val], axis=0)
            y_nn = np.concatenate([y_train, y_val], axis=0)

            p_ic, nn_test_score = train_oracle_and_predict(dataset, X_nn, y_nn, X_test, y_test, c, device)

            f_med, f_all, f_m = compress_tree(max_tree_depth, X_train, y_train, p_ic, weight=weight, X_val=X_val,
                                              y_val=y_val, score=score_metric, delta=delta)
            f_scores.append(score_func(None, None, f_med, f_all, f_m, X_train, y_train, X_test, y_test))
            nn_score.append(nn_test_score)

        mean_var_win = metrics.mean_and_std(f_scores, mean_only=False, show_rf=False, nn_score=nn_score)
        kf_scores.append(mean_var_win)

    print('\nFinal results:')
    print(f'Average NN mean {sum([score[0] for score in kf_scores]) / num_experiments}, std {sum([score[1] for score in kf_scores]) / num_experiments}')
    idx = 2
    for t in ('BM', 'VT', 'MED'):
        t_mean = sum([score[idx] for score in kf_scores]) / num_experiments
        t_wins = sum([score[idx + 2] for score in kf_scores]) / num_experiments
        idx += 3
        print(f'Average {t} mean {t_mean}, wins {t_wins}')

    return
Example #15
    def train(self):
        torch.autograd.set_detect_anomaly(True)

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
            batch_sizes = self.batch_size
            noise, local_noise, fixed_noise = [], [], []
            for batch_size in batch_sizes:
                noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE))
                local_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE))
                fixed_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE))
        else:
            batch_size = self.batch_size[0]
            noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE)
            local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)
            fixed_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE)

        for epoch in range(start_epoch, self.max_epoch):
            logger.info("Epoch nb: %s" % epoch)
            gen_iterations = 0
            if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                data_iter = []
                for _idx in range(len(self.data_loader)):
                    data_iter.append(iter(self.data_loader[_idx]))
                total_batches_left = sum([len(self.data_loader[i]) for i in range(len(self.data_loader))])
                current_probability = [len(self.data_loader[i]) for i in range(len(self.data_loader))]
                current_probability_percent = [current_probability[i] / float(total_batches_left) for i in
                                               range(len(current_probability))]
            else:
                data_iter = iter(self.data_loader)

            _dataset = tqdm(range(self.num_batches))
            for step in _dataset:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    subset_idx = np.random.choice(range(len(self.data_loader)), size=None,
                                                  p=current_probability_percent)
                    total_batches_left -= 1
                    if total_batches_left > 0:
                        current_probability[subset_idx] -= 1
                        current_probability_percent = [current_probability[i] / float(total_batches_left) for i in
                                                       range(len(current_probability))]

                    max_objects = subset_idx
                    data = next(data_iter[subset_idx])  # Python 3 iterator protocol
                else:
                    data = next(data_iter)
                    max_objects = 3
                _dataset.set_description('Obj-{}'.format(max_objects))

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                with torch.no_grad():
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        hidden = text_encoder.init_hidden(batch_sizes[subset_idx])
                    else:
                        hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0).bool()
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    noise[subset_idx].data.normal_(0, 1)
                    local_noise[subset_idx].data.normal_(0, 1)
                    inputs = (noise[subset_idx], local_noise[subset_idx], sent_emb, words_embs, mask, transf_matrices,
                              transf_matrices_inv, label_one_hot, max_objects)
                else:
                    noise.data.normal_(0, 1)
                    local_noise.data.normal_(0, 1)
                    inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv,
                              label_one_hot, max_objects)

                inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs)
                fake_imgs, _, mu, logvar = netG(*inputs)

                #######################################################
                # (3) Update D network
                ######################################################
                # errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels[subset_idx], fake_labels[subset_idx],
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                # step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                netG.zero_grad()
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels[subset_idx],
                                       words_embs, sent_emb, match_labels[subset_idx], cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                else:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                       words_embs, sent_emb, match_labels, cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if cfg.TRAIN.EMPTY_CACHE:
                    torch.cuda.empty_cache()

                # save images
                if (
                        2 * gen_iterations == self.num_batches
                        or 2 * gen_iterations + 1 == self.num_batches
                        or gen_iterations + 1 == self.num_batches
                ):
                    logger.info('Saving images...')
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        self.save_img_results(netG, fixed_noise[subset_idx], sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise[subset_idx], transf_matrices,
                                              max_objects, subset_idx, name='average')
                    else:
                        self.save_img_results(netG, fixed_noise, sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise, transf_matrices,
                                              max_objects, None, name='average')
                    load_params(netG, backup_para)

            self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
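
The `avg_param_G` update above keeps an exponential moving average of the generator weights (decay 0.999); the image-saving branch shows the standard swap-in pattern, sketched here:

backup_para = copy_G_params(netG)   # stash the live training weights
load_params(netG, avg_param_G)      # evaluate / sample with the EMA weights
# ... generate or save images ...
load_params(netG, backup_para)      # restore training weights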
Example #16
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm_` helps prevent the exploding gradient
        # problem in RNNs / LSTMs (the un-suffixed alias is deprecated).
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'.format(
                      epoch, step, len(dataloader),
                      elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0,
                      s_cur_loss1, w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
Example #17
def compute_ppl(evaluator,
                space='smart',
                num_samples=100000,
                eps=1e-4,
                net='vgg'):
    """Perceptual Path Length: PyTorch implementation of the `PPL` class in
       https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
    """
    assert space in ['z', 'w', 'smart']
    assert net in ['vgg', 'alex']

    ppl_loss_fn = lpips.LPIPS(net=net, lpips=True)
    ppl_loss_fn.cuda()

    text_encoder, netG = evaluator.build_models_eval(init_func=weights_init)
    if space == 'smart':
        space = 'w' if cfg.GAN.B_STYLEGEN else 'z'
    if space == 'w':
        init_res = (
            netG.h_net1.init_layer.shape[-2],
            netG.h_net1.init_layer.shape[-1],
        )
        upscale_fctr = netG.h_net1.upsampler.scale_factor
        res_init_layers = [
            int(np.rint(r * upscale_fctr**((n - n % 2) // 2)))
            for n, r in enumerate(init_res * 5)
        ]
        res_2G = int(np.rint(netG.h_net2.res))
        res_3G = int(np.rint(netG.h_net3.res))

    batch_size = evaluator.batch_size
    nz = cfg.GAN.Z_DIM
    with torch.no_grad():
        z_code01 = Variable(torch.FloatTensor(batch_size * 2, nz))
        z_code01 = z_code01.cuda()
        t = Variable(torch.FloatTensor(batch_size, 1))
        t = t.cuda()

    ppls = []
    dl_itr = iter(evaluator.data_loader)
    # for step, data in enumerate( evaluator.data_loader, 0 ):
    pbar = tqdm(range(num_samples // batch_size), dynamic_ncols=True)
    for step in pbar:
        try:
            data = next(dl_itr)
        except StopIteration:
            dl_itr = iter(evaluator.data_loader)
            data = next(dl_itr)

        if step % 100 == 0:
            pbar.set_description('step: {}'.format(step))

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        #######################################################
        # (1) Extract text embeddings
        ######################################################
        if evaluator.text_encoder_type == 'rnn':
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
        elif evaluator.text_encoder_type == 'transformer':
            words_embs = text_encoder(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_embs[:, :, -1].contiguous()
        # words_embs: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
        mask = (captions == 0)
        num_words = words_embs.size(2)
        if mask.size(1) > num_words:
            mask = mask[:, :num_words]

        #######################################################
        # (2) Generate fake images
        ######################################################
        t.data.uniform_(0, 1)
        z_code01.data.normal_(0, 1)
        c_code, _, _ = netG.ca_net(sent_emb)
        sent_emb = torch.cat((
            sent_emb,
            sent_emb,
        ), 0).detach()
        words_embs = torch.cat((
            words_embs,
            words_embs,
        ), 0).detach()
        mask = torch.cat((
            mask,
            mask,
        ), 0).detach()
        if space == 'w':
            # Control out the StyleGAN noise, as we are trying to measure feature interpolability only in PPL
            netG.noise_net1 = [
                torch.randn(batch_size,
                            1,
                            res,
                            res,
                            dtype=torch.float32,
                            device=z_code01.device) for res in res_init_layers
            ]
            netG.noise_net1 = [
                torch.cat((
                    noise,
                    noise,
                ), 0).detach() for noise in netG.noise_net1
            ]
            netG.noise_net2 = [
                torch.randn(batch_size,
                            1,
                            res_2G,
                            res_2G,
                            dtype=torch.float32,
                            device=z_code01.device) for _ in range(2)
            ]
            netG.noise_net2 = [
                torch.cat((
                    noise,
                    noise,
                ), 0).detach() for noise in netG.noise_net2
            ]
            netG.noise_net3 = [
                torch.randn(batch_size,
                            1,
                            res_3G,
                            res_3G,
                            dtype=torch.float32,
                            device=z_code01.device) for _ in range(2)
            ]
            netG.noise_net3 = [
                torch.cat((
                    noise,
                    noise,
                ), 0).detach() for noise in netG.noise_net3
            ]
            w_code01 = netG.map_net(z_code01, torch.cat((
                c_code,
                c_code,
            ), 0))
            w_code0, w_code1 = w_code01[0::2], w_code01[1::2]
            w_code0_lerp = lerp(w_code0, w_code1, t)
            w_code1_lerp = lerp(w_code0, w_code1, t + eps)
            w_code_lerp = torch.cat((
                w_code0_lerp,
                w_code1_lerp,
            ), 0).detach()
            fake_imgs01, _, _, _ = netG(w_code_lerp,
                                        sent_emb,
                                        words_embs,
                                        mask,
                                        is_dlatent=True)
        else:
            z_code0, z_code1 = z_code01[0::2], z_code01[1::2]
            z_code0_slerp = slerp(z_code0, z_code1, t)
            z_code1_slerp = slerp(z_code0, z_code1, t + eps)
            z_code_slerp = torch.cat((
                z_code0_slerp,
                z_code1_slerp,
            ), 0).detach()
            fake_imgs01, _, _, _ = netG(z_code_slerp, sent_emb, words_embs,
                                        mask)

        fake_imgs01 = fake_imgs01[-1]

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if fake_imgs01.shape[2] > 256:
            factor = fake_imgs01.shape[2] // 256
            fake_imgs01 = torch.reshape(fake_imgs01, [
                -1, fake_imgs01.shape[1], fake_imgs01.shape[2] // factor,
                factor, fake_imgs01.shape[3] // factor, factor
            ])
            fake_imgs01 = torch.mean(fake_imgs01, dim=(
                3,
                5,
            ), keepdim=False)

        # # Scale dynamic range to [-1,1] for the lpips VGG.
        # fake_imgs01 = (fake_imgs01 + 1) * (255 / 2)
        # fake_imgs01.clamp_( 0, 255 )
        fake_imgs01.clamp_(-1., 1.)

        fake_imgs0, fake_imgs1 = fake_imgs01[:batch_size], fake_imgs01[
            batch_size:]
        # fake_imgs0, fake_imgs1 = fake_imgs01[0::2], fake_imgs01[1::2]

        # Evaluate perceptual distances.
        ppls.append(
            ppl_loss_fn.forward(fake_imgs0,
                                fake_imgs1).squeeze().detach().cpu().numpy() *
            (1 / eps**2))  # scale by the eps argument rather than a hard-coded 1e-4

    ppls = np.concatenate(ppls, axis=0)

    # Reject outliers.
    lo = np.percentile(ppls, 1, interpolation='lower')
    hi = np.percentile(ppls, 99, interpolation='higher')
    ppls = np.extract(np.logical_and(lo <= ppls, ppls <= hi), ppls)

    return np.mean(ppls)
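
lerp() and slerp() are used above but not defined in this excerpt; hedged helpers in the spirit of the StyleGAN PPL reference (the batched-tensor shapes are assumptions):

import torch

def lerp(a, b, t):
    # linear interpolation, used in W space
    return a + (b - a) * t

def slerp(a, b, t):
    # spherical interpolation, used in Z space (degenerate if a and b are parallel)
    a_n = a / a.norm(dim=-1, keepdim=True)
    b_n = b / b.norm(dim=-1, keepdim=True)
    omega = torch.acos((a_n * b_n).sum(-1, keepdim=True).clamp(-1.0, 1.0))
    so = torch.sin(omega)
    return (torch.sin((1.0 - t) * omega) / so) * a + (torch.sin(t * omega) / so) * b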
Example #18
    def train(self):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                batch_t_begin = time.time()
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, color_ids, sleeve_ids, gender_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                D_logs_cls = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    imgs[i] = gaussian_to_input(imgs[i]) ## INSTANCE NOISE
                    errD, cls_D = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                     sent_emb, real_labels, fake_labels, class_ids,
                                                     color_ids, sleeve_ids, gender_ids)
                    errD_both = errD + cls_D / 3.
                    # backward and update parameters
                    errD_both.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    errD_total += cls_D/3.0
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)
                    D_logs_cls += 'clsD%d: %.2f ' % (i, cls_D.data)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, color_ids,sleeve_ids, gender_ids,imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)  # modern in-place add_ signature

                if gen_iterations % 100 == 0:
                    batch_t_end = time.time()
                    print('| epoch {:3d} | {:5d}/{:5d} batches | batch_timer: {:5.2f} | '
                          .format(epoch, step, self.num_batches,
                                  batch_t_end - batch_t_begin,))
                    print(D_logs + '\n' + D_logs_cls + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data, errG_total.data,
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Example #19
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)  # volatile is removed in modern PyTorch; use torch.no_grad() instead
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0
            idx = 0  ###

            avg_ddva = 0
            for _ in range(1):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                    # Generate images for human-text ----------------------------------------------------------------
                    data_human = [captions, cap_lens, misc]

                    imgs, captions, cap_lens, class_ids, keys, wrong_caps,\
                                wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                    hidden = text_encoder.init_hidden(batch_size)
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask)

                    # Generate images for imperfect caption-text-------------------------------------------------------
                    data_imperfect = [
                        imperfect_captions, imperfect_cap_lens, misc
                    ]

                    imgs, imperfect_captions, imperfect_cap_lens, class_ids, imperfect_keys, wrong_caps,\
                                wrong_caps_len, wrong_cls_id = prepare_data(data_imperfect)

                    hidden = text_encoder.init_hidden(batch_size)
                    words_embs, sent_emb = text_encoder(imperfect_captions, imperfect_cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (imperfect_captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    noise.data.normal_(0, 1)
                    imperfect_fake_imgs, _, _, _ = netG(
                        noise, sent_emb, words_embs, mask)

                    # Sort the results by keys to align ----------------------------------------------------------------
                    keys, captions, cap_lens, fake_imgs, _, _ = sort_by_keys(
                        keys, captions, cap_lens, fake_imgs, None, None)

                    imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, true_imgs, _ = \
                                sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs,\
                                             imgs, None)

                    # Shift the imgs, fake and imperfect-fake images to the secondary device ---------------------------
                    # (`secondary_device` is assumed to be defined at module level)
                    for i in range(len(imgs)):
                        imgs[i] = imgs[i].to(secondary_device)
                        imperfect_fake_imgs[i] = imperfect_fake_imgs[i].to(
                            secondary_device)
                        fake_imgs[i] = fake_imgs[i].to(secondary_device)

                    for j in range(batch_size):
                        s_tmp = '%s/single' % (save_dir)
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        im = fake_imgs[k][j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))

                        cap_im = imperfect_fake_imgs[k][j].data.cpu().numpy()
                        cap_im = (cap_im + 1.0) * 127.5
                        cap_im = cap_im.astype(np.uint8)
                        cap_im = np.transpose(cap_im, (1, 2, 0))

                        # scale the ground-truth image the same way (saving it is optional, see below)
                        true_im = true_imgs[k][j].data.cpu().numpy()
                        true_im = (true_im + 1.0) * 127.5
                        true_im = true_im.astype(np.uint8)
                        true_im = np.transpose(true_im, (1, 2, 0))

                        # Uncomment to save images.
                        #true_im = Image.fromarray(true_im)
                        #fullpath = '%s_true_s%d.png' % (s_tmp, idx)
                        #true_im.save(fullpath)
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, idx)
                        im.save(fullpath)
                        #cap_im = Image.fromarray(cap_im)
                        #fullpath = '%s_imperfect_s%d.png' % (s_tmp, idx)
                        idx = idx + 1
                        #cap_im.save(fullpath)

                    neg_ddva = negative_ddva(
                        imperfect_fake_imgs,
                        imgs,
                        fake_imgs,
                        reduce='mean',
                        final_only=True).data.cpu().numpy()
                    avg_ddva += neg_ddva * (-1)

                    #text_caps = [[self.ixtoword[word] for word in sent if word!=0] for sent in captions.tolist()]

                    #imperfect_text_caps = [[self.ixtoword[word] for word in sent if word!=0] for sent in
                    #                       imperfect_captions.tolist()]

                    print(step)
            avg_ddva = avg_ddva / (step + 1)
            print('\n\nAvg_DDVA: ', avg_ddva)
Example #20
    def train(self):
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch, style_loss = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ----------------------------------------------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # wrong word and sentence embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(), w_sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for imperfect caption-text-------------------------------------------------------

                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps,\
                            i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Move tensors to the secondary device.
                # IMPORTANT: reuse the same noise tensor on the secondary device.
                noise = noise.to(secondary_device)
                i_sent_emb = i_sent_emb.to(secondary_device)
                i_words_embs = i_words_embs.to(secondary_device)
                i_mask = i_mask.to(secondary_device)

                # Generate images.
                imperfect_fake_imgs, _, _, _ = target_netG(
                    noise, i_sent_emb, i_words_embs, i_mask)

                # Sort the results by keys to align ------------------------------------------------------------------------
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(keys, captions, cap_lens, fake_imgs,\
                                                                                  None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, w_words_embs, wrong_caps_len, wrong_cls_id = \
                            sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                            sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs,None)

                #-----------------------------------------------------------------------------------------------------------

                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss

                # Shift device for the imgs and target_imgs.-----------------------------------------------------
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Compute and add ddva loss ---------------------------------------------------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so that the ddva score is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)
                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva
                #------------------------------------------------------------------------------------------------

                errG_total.backward()

                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # Copy parameters to the target network.
                if gen_iterations % 20 == 0:
                    load_params(target_netG, copy_G_params(netG))

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f neg_ddva: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, neg_ddva, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
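avg_param_G above holds an exponential moving average of the generator weights (decay 0.999), periodically copied into target_netG via load_params. A minimal sketch of the two helpers, assuming they behave as in AttnGAN:

from copy import deepcopy

def copy_G_params(model):
    # Snapshot the generator's parameters as detached tensor copies.
    return deepcopy([p.data for p in model.parameters()])

def load_params(model, new_params):
    # Overwrite the model's parameters in place with the stored copies.
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)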
Example #21
    def embedding(self, split_dir, model):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            if cfg.GPU_ID != -1:
                netG.cuda()
            netG.eval()
            #
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            print(img_encoder_path)
            print('Load image encoder from:', img_encoder_path)
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            if cfg.GPU_ID != -1:
                image_encoder = image_encoder.cuda()
            image_encoder.eval()

            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            if cfg.GPU_ID != -1:
                text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM

            with torch.no_grad():
                noise = Variable(torch.FloatTensor(batch_size, nz))
                if cfg.GPU_ID != -1:
                    noise = noise.cuda()

            # the path to save generated images
            save_dir = model_dir[:model_dir.rfind('.pth')]

            cnt = 0

            # new
            if cfg.TRAIN.CLIP_SENTENCODER:
                print("Use CLIP SentEncoder for sampling")
            img_features = dict()
            txt_features = dict()

            with torch.no_grad():
                for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                    for step, data in enumerate(self.data_loader, 0):
                        cnt += batch_size
                        if step % 100 == 0:
                            print('step: ', step)

                        imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                            data)

                        hidden = text_encoder.init_hidden(batch_size)
                        # words_embs: batch_size x nef x seq_len
                        # sent_emb: batch_size x nef
                        words_embs, sent_emb = text_encoder(
                            captions, cap_lens, hidden)
                        words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                        mask = (captions == 0)
                        num_words = words_embs.size(2)
                        if mask.size(1) > num_words:
                            mask = mask[:, :num_words]

                        if cfg.TRAIN.CLIP_SENTENCODER:

                            # random select one paragraph for each training example
                            sents = []
                            for idx in range(len(texts)):
                                sents_per_image = texts[idx].split(
                                    '\n')  # new 3/11
                                if len(sents_per_image) > 1:
                                    sent_ix = np.random.randint(
                                        0,
                                        len(sents_per_image) - 1)
                                else:
                                    sent_ix = 0
                                sents.append(sents_per_image[sent_ix])
                            # print('sents: ', sents)

                            sent = clip.tokenize(sents)  # .to(device)

                            # load clip
                            #model = torch.jit.load("model.pt").cuda().eval()
                            sent_input = sent
                            if cfg.GPU_ID != -1:
                                sent_input = sent.cuda()
                            # print("text input", sent_input)
                            sent_emb_clip = model.encode_text(
                                sent_input).float()
                            if CLIP:
                                sent_emb = sent_emb_clip
                        #######################################################
                        # (2) Generate fake images
                        ######################################################
                        noise.data.normal_(0, 1)
                        fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                                  mask)
                        if CLIP:
                            images = []
                            for j in range(fake_imgs[-1].shape[0]):
                                image = fake_imgs[-1][j].cpu().clone()
                                image = image.squeeze(0)
                                unloader = transforms.ToPILImage()
                                image = unloader(image)

                                image = preprocess(
                                    image.convert("RGB"))  # 256*256 -> 224*224
                                images.append(image)

                            image_mean = torch.tensor(
                                [0.48145466, 0.4578275, 0.40821073]).cuda()
                            image_std = torch.tensor(
                                [0.26862954, 0.26130258, 0.27577711]).cuda()

                            image_input = torch.tensor(np.stack(images)).cuda()
                            image_input -= image_mean[:, None, None]
                            image_input /= image_std[:, None, None]
                            cnn_codes = model.encode_image(image_input).float()
                        else:
                            region_features, cnn_codes = image_encoder(
                                fake_imgs[-1])
                        for j in range(batch_size):
                            cnn_code = cnn_codes[j]

                            # strip the b'...' bytes-literal artifacts from the key
                            temp = keys[j].replace('b', '').replace("'", '')
                            img_features[temp] = cnn_code.cpu().numpy()
                            txt_features[temp] = sent_emb[j].cpu().numpy()
            with open(save_dir + ".pkl", 'wb') as f:
                pickle.dump(img_features, f)
            with open(save_dir + "_text.pkl", 'wb') as f:
                pickle.dump(txt_features, f)
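The two pickles written above map image keys to generated-image CNN codes and sentence embeddings. A hedged sketch of how they might be read back for cosine-similarity retrieval (save_dir as defined above):

import pickle
import numpy as np

with open(save_dir + '.pkl', 'rb') as f:
    img_features = pickle.load(f)
with open(save_dir + '_text.pkl', 'rb') as f:
    txt_features = pickle.load(f)

keys = sorted(img_features)
img_mat = np.stack([img_features[k] for k in keys])        # N x nef
img_mat /= np.linalg.norm(img_mat, axis=1, keepdims=True)  # unit-normalize rows

query = txt_features[keys[0]] / np.linalg.norm(txt_features[keys[0]])
scores = img_mat @ query                                   # cosine similarity per image
print('best image match for text %s: %s' % (keys[0], keys[int(np.argmax(scores))]))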
Example #22
    def sampling(self, split_dir, model):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            if cfg.GPU_ID != -1:
                netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            if cfg.GPU_ID != -1:
                text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM

            with torch.no_grad():
                noise = Variable(torch.FloatTensor(batch_size, nz))
                if cfg.GPU_ID != -1:
                    noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0

            #new
            if cfg.TRAIN.CLIP_SENTENCODER:
                print("Use CLIP SentEncoder for sampling")

            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)
                    # if step > 50:
                    #     break

                    #imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
                    #new
                    imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                        data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    # new
                    if cfg.TRAIN.CLIP_SENTENCODER:

                        # random select one paragraph for each training example
                        sents = []
                        for idx in range(len(texts)):
                            sents_per_image = texts[idx].split(
                                '\n')  # new 3/11
                            if len(sents_per_image) > 1:
                                sent_ix = np.random.randint(
                                    0,
                                    len(sents_per_image) - 1)
                            else:
                                sent_ix = 0
                            sents.append(sents_per_image[sent_ix])
                            with open('%s/%s' % (save_dir, 'eval_sents.txt'),
                                      'a+') as f:
                                f.write(sents_per_image[sent_ix] + '\n')
                        # print('sents: ', sents)

                        sent = clip.tokenize(sents)  # .to(device)

                        # load clip
                        #model = torch.jit.load("model.pt").cuda().eval()
                        sent_input = sent
                        if cfg.GPU_ID != -1:
                            sent_input = sent.cuda()
                        # print("text input", sent_input)
                        with torch.no_grad():
                            sent_emb = model.encode_text(sent_input).float()

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask)
                    for j in range(batch_size):
                        s_tmp = '%s/fake/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                            print('Make a new folder: ', f'{save_dir}/real')
                            mkdir_p(f'{save_dir}/real')
                            print('Make a new folder: ', f'{save_dir}/text')
                            mkdir_p(f'{save_dir}/text')
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, k)
                        im.save(fullpath)
                        # strip the b'...' bytes-literal artifacts from the key
                        temp = keys[j].replace('b', '').replace("'", '')
                        shutil.copy(f"../data/Face/images/{temp}.jpg",
                                    f"{save_dir}/real/")
                        shutil.copy(f"../data/Face/text/{temp}.txt",
                                    f"{save_dir}/text/")
Example #23
    def sampling(self, split_dir, num_samples=30000):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz))
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict["netG"])
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0

            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 10000 == 0:
                        print('step: ', step)
                    if step >= num_samples:
                        break

                    imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                    transf_matrices_inv = transformation_matrices[1]

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                    with torch.no_grad():
                        fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, k)
                        im.save(fullpath)
Example #24
    def sampling(self, split_dir, num_samples=30000):
        if cfg.TRAIN.NET_G == '':
            logger.error('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.to(cfg.DEVICE)
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            text_encoder = text_encoder.to(cfg.DEVICE)
            text_encoder.eval()
            logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E)

            batch_size = self.batch_size[0]
            nz = cfg.GAN.GLOBAL_Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE)
            local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            max_objects = 10
            logger.info('Load G from: %s', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
            save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
            mkdir_p(save_dir)
            logger.info("Saving images to: {}".format(save_dir))

            number_batches = num_samples // batch_size
            if number_batches < 1:
                number_batches = 1

            data_iter = iter(self.data_loader)

            for step in tqdm(range(number_batches)):
                data = next(data_iter)

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = prepare_data(
                    data, eval=True)

                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                local_noise.data.normal_(0, 1)
                inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv, label_one_hot, max_objects)
                inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs)

                with torch.no_grad():
                    fake_imgs, _, mu, logvar = netG(*inputs)
                for j in range(batch_size):
                    s_tmp = '%s/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        logger.info('Make a new folder: %s', folder)
                        mkdir_p(folder)
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, step * batch_size + j)
                    im.save(fullpath)
Example #25
def train(dataloader, netG, netD, text_encoder, optimizerG, optimizerD_enc,
          optimizerD_proj, state_epoch, batch_size, device, output_dir,
          logger):
    #i = -1
    for epoch in range(state_epoch + 1, cfg.TRAIN.MAX_EPOCH + 1):

        for step, data in enumerate(dataloader, 0):
            #i+=1
            images, captions, cap_lens, class_ids, keys = prepare_data(data)
            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

            imgs = images[0].to(device)
            real_features = netD(imgs)
            output = netD.COND_DNET(real_features, sent_emb)
            errD_real = torch.nn.ReLU()(1.0 - output).mean()

            output = netD.COND_DNET(real_features[:(batch_size - 1)],
                                    sent_emb[1:batch_size])
            errD_mismatch = torch.nn.ReLU()(1.0 + output).mean()

            # synthesize fake images
            noise = torch.randn(batch_size, 100)
            noise = noise.to(device)
            fake = netG(noise, sent_emb)

            #if not cfg.TRAIN.ONLY_REAL:
            # G does not need update with D

            #if cfg.TRAIN.ONLY_REAL:
            #    for p in netD.COND_DNET.parameters():
            #        p.requires_grad_(False)

            fake_features = netD(fake.detach())
            errD_fake = netD.COND_DNET(fake_features, sent_emb)
            errD_fake = torch.nn.ReLU()(1.0 + errD_fake).mean()

            #if cfg.TRAIN.ONLY_REAL:
            #    for p in netD.COND_DNET.parameters():
            #        p.requires_grad_(True)

            errD_enc = errD_real + (errD_fake + errD_mismatch) / 2.0
            optimizerD_enc.zero_grad()
            optimizerD_proj.zero_grad()
            optimizerG.zero_grad()
            errD_enc.backward(retain_graph=True)
            optimizerD_enc.step()

            errD_proj = errD_real + errD_mismatch
            optimizerD_enc.zero_grad()
            optimizerD_proj.zero_grad()
            optimizerG.zero_grad()
            errD_proj.backward()
            optimizerD_proj.step()

            #MA-GP
            interpolated = (imgs.data).requires_grad_()
            sent_inter = (sent_emb.data).requires_grad_()
            features = netD(interpolated)
            out = netD.COND_DNET(features, sent_inter)
            grads = torch.autograd.grad(outputs=out,
                                        inputs=(interpolated, sent_inter),
                                        grad_outputs=torch.ones(
                                            out.size()).cuda(),
                                        retain_graph=True,
                                        create_graph=True,
                                        only_inputs=True)
            grad0 = grads[0].view(grads[0].size(0), -1)
            grad1 = grads[1].view(grads[1].size(0), -1)
            grad = torch.cat((grad0, grad1), dim=1)
            grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm)**6)
            d_loss = 2.0 * d_loss_gp

            optimizerD_enc.zero_grad()
            optimizerD_proj.zero_grad()
            optimizerG.zero_grad()
            d_loss.backward(retain_graph=True)
            optimizerD_enc.step()

            optimizerD_enc.zero_grad()
            optimizerD_proj.zero_grad()
            optimizerG.zero_grad()
            d_loss.backward()
            optimizerD_proj.step()

            # update G

            #if (i+1) % cfg.TRAIN.N_CRITIC == 0:
            #i = -1
            features = netD(fake)
            output = netD.COND_DNET(features, sent_emb)
            errG = -output.mean()
            optimizerG.zero_grad()
            optimizerD_enc.zero_grad()
            optimizerD_proj.zero_grad()
            errG.backward()
            optimizerG.step()

            logger.info(
                '[%d/%d][%d/%d] Loss_encD: %.3f Loss_projD: %.3f Loss_G %.3f errD_real %.3f errD_mis %.3f errD_fake %.3f magp %.3f'
                % (epoch, cfg.TRAIN.MAX_EPOCH, step, len(dataloader),
                   errD_enc.item(), errD_proj.item(), errG.item(),
                   errD_real.item(), errD_mismatch.item(), errD_fake.item(),
                   d_loss_gp.item()))

        vutils.save_image(fake.data,
                          '%s/imgs/fake_samples_epoch_%03d.png' %
                          (output_dir, epoch),
                          normalize=True)

        if epoch % 1 == 0:
            torch.save(netG.state_dict(),
                       '%s/models/netG_%03d.pth' % (output_dir, epoch))
            torch.save(netD.state_dict(),
                       '%s/models/netD_%03d.pth' % (output_dir, epoch))
            torch.save(optimizerG.state_dict(),
                       '%s/models/optimizerG.pth' % (output_dir))
            # optimizerD was undefined here; save both discriminator optimizers.
            torch.save(optimizerD_enc.state_dict(),
                       '%s/models/optimizerD_enc.pth' % (output_dir))
            torch.save(optimizerD_proj.state_dict(),
                       '%s/models/optimizerD_proj.pth' % (output_dir))

    return
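The gradient-penalty block in the loop above is a matching-aware gradient penalty (MA-GP) in the style of DF-GAN: the joint gradient norm of the matching score with respect to the real image and its sentence embedding, raised to the 6th power and scaled by 2. Factored into a function, assuming the same netD interface:

import torch

def ma_gp(netD, imgs, sent_emb, k=2.0, p=6):
    # Penalize the joint gradient norm of D's matching score w.r.t. the
    # real images and their matching sentence embeddings.
    interpolated = imgs.data.requires_grad_()
    sent_inter = sent_emb.data.requires_grad_()
    out = netD.COND_DNET(netD(interpolated), sent_inter)
    grads = torch.autograd.grad(outputs=out,
                                inputs=(interpolated, sent_inter),
                                grad_outputs=torch.ones_like(out),
                                retain_graph=True,
                                create_graph=True,
                                only_inputs=True)
    grad = torch.cat([g.view(g.size(0), -1) for g in grads], dim=1)
    return k * torch.mean(grad.norm(2, dim=1) ** p)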
Example #26
    def sampling(self, split_dir):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()

            # load text encoder
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = torch.load(cfg.TRAIN.NET_E,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            #load image encoder
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = torch.load(img_encoder_path,
                                    map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz))  # 'volatile' was removed in modern PyTorch
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0
            R_count = 0
            R = np.zeros(30000)
            cont = True
            for ii in range(11):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                if not cont:
                    break
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if not cont:
                        break
                    if step % 100 == 0:
                        print('cnt: ', cnt)
                    # if step > 50:
                    #     break

                    imgs, captions, cap_lens, class_ids, keys = prepare_data(
                        data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask, cap_lens)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            #print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d_%d.png' % (s_tmp, k, ii)
                        im.save(fullpath)

                    _, cnn_code = image_encoder(fake_imgs[-1])

                    for i in range(batch_size):
                        mis_captions, mis_captions_len = self.dataset.get_mis_caption(
                            class_ids[i])
                        hidden = text_encoder.init_hidden(99)
                        _, sent_emb_t = text_encoder(mis_captions,
                                                     mis_captions_len, hidden)
                        rnn_code = torch.cat(
                            (sent_emb[i, :].unsqueeze(0), sent_emb_t), 0)
                        ### cnn_code = 1 * nef
                        ### rnn_code = 100 * nef
                        scores = torch.mm(cnn_code[i].unsqueeze(0),
                                          rnn_code.transpose(0, 1))  # 1* 100
                        cnn_code_norm = torch.norm(cnn_code[i].unsqueeze(0),
                                                   2,
                                                   dim=1,
                                                   keepdim=True)
                        rnn_code_norm = torch.norm(rnn_code,
                                                   2,
                                                   dim=1,
                                                   keepdim=True)
                        norm = torch.mm(cnn_code_norm,
                                        rnn_code_norm.transpose(0, 1))
                        scores0 = scores / norm.clamp(min=1e-8)
                        if torch.argmax(scores0) == 0:
                            R[R_count] = 1
                        R_count += 1

                    if R_count >= 30000:
                        r_folds = np.zeros(10)
                        np.random.shuffle(R)
                        for i in range(10):
                            # full folds of 3000 (the original slice dropped the last element)
                            r_folds[i] = np.average(R[i * 3000:(i + 1) * 3000])
                        R_mean = np.average(r_folds)
                        R_std = np.std(r_folds)
                        print("R mean:{:.4f} std:{:.4f}".format(R_mean, R_std))
                        cont = False
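The block above computes R-precision: each generated image's CNN code is scored against its true caption embedding plus 99 mismatched ones, and a hit is counted when the true caption ranks first; the 30000 hits are then split into 10 folds of 3000 for a mean and standard deviation. The core ranking step in isolation (cnn_code_i: 1 x nef, rnn_code: 100 x nef with the true caption in row 0; names are illustrative):

scores = torch.mm(cnn_code_i, rnn_code.t())            # 1 x 100 dot products
norms = torch.mm(cnn_code_i.norm(2, dim=1, keepdim=True),
                 rnn_code.norm(2, dim=1, keepdim=True).t())
cosine = scores / norms.clamp(min=1e-8)                # cosine similarities
hit = bool(torch.argmax(cosine) == 0)                  # true caption ranked first?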
Example #27
    def train(self):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if i == 0:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # save images
                if gen_iterations % 1000 == 0:
                    print(D_logs + '\n' + G_logs)

                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch,  transf_matrices_inv,
                                          label_one_hot, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)

        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, self.max_epoch)
Example #28
    def genDiscOutputs(self, split_dir, num_samples=57140):
        if cfg.TRAIN.NET_G == '':
            logger.error('Error: the path for morels is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.to(cfg.DEVICE)
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)  ###HACK
            state_dict = torch.load(cfg.TRAIN.NET_E,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            text_encoder = text_encoder.to(cfg.DEVICE)
            text_encoder.eval()
            logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E)

            batch_size = self.batch_size[0]
            nz = cfg.GAN.GLOBAL_Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE)
            local_noise = Variable(
                torch.FloatTensor(batch_size,
                                  cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            for key in state_dict.keys():
                print(key)
            logger.info('Load G from: %s', model_dir)
            max_objects = 3
            from model import D_NET256
            netD = D_NET256()
            netD.load_state_dict(state_dict["netD"][2])

            netD.to(cfg.DEVICE)
            netD.eval()

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
            save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
            mkdir_p(save_dir)
            logger.info("Saving images to: {}".format(save_dir))

            number_batches = num_samples // batch_size
            if number_batches < 1:
                number_batches = 1

            data_iter = iter(self.data_loader)
            real_labels, fake_labels, match_labels = self.prepare_labels()

            for step in tqdm(range(number_batches)):
                data = next(data_iter)

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = prepare_data(
                    data, eval=True)

                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                local_noise.data.normal_(0, 1)
                inputs = (noise, local_noise, sent_emb, words_embs, mask,
                          transf_matrices, transf_matrices_inv, label_one_hot,
                          max_objects)
                inputs = tuple(
                    (inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor
                                                      ) else inp)
                    for inp in inputs)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = netG(*inputs)
                    inputs = (fake_imgs, fake_labels, transf_matrices,
                              transf_matrices_inv, max_objects)
                    codes = netD.partial_forward(*inputs)
Example #29
    def sample(self, split_dir, num_samples=25, draw_bbox=False):
        from PIL import Image, ImageDraw, ImageFont
        import pickle
        import torchvision
        import torchvision.utils as vutils

        if cfg.TRAIN.NET_G == '':
            print('Error: the path for model NET_G is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = cfg.TRAIN.BATCH_SIZE
            nz = cfg.GAN.Z_DIM

            model_dir = cfg.TRAIN.NET_G
            state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG = G_NET()
            print('Load G from: ', model_dir)
            netG.apply(weights_init)

            netG.load_state_dict(state_dict["netG"])
            netG.cuda()
            netG.eval()

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s_%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)
            #######################################
            noise = Variable(torch.FloatTensor(9, nz))
            noise = noise.cuda()  # netG runs on the GPU below

            imsize = 256

            for step, data in enumerate(self.data_loader, 0):
                if step >= num_samples:
                    break

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, bbox = \
                    prepare_data(data, eval=True)
                transf_matrices_inv = transformation_matrices[1][0].unsqueeze(0)
                label_one_hot = label_one_hot[0].unsqueeze(0)

                img = imgs[-1][0]
                val_image = img.view(1, 3, imsize, imsize)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs[0].unsqueeze(0).detach(), sent_emb[0].unsqueeze(0).detach()
                words_embs = words_embs.repeat(9, 1, 1)
                sent_emb = sent_emb.repeat(9, 1)
                mask = (captions == 0)
                mask = mask[0].unsqueeze(0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                mask = mask.repeat(9, 1)
                transf_matrices_inv = transf_matrices_inv.repeat(9, 1, 1, 1)
                label_one_hot = label_one_hot.repeat(9, 1, 1)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
                data_img[0] = val_image
                data_img[1:10] = fake_imgs[-1]

                if draw_bbox:
                    for idx in range(3):
                        x, y, w, h = tuple([int(imsize*x) for x in bbox[0, idx]])
                        w = imsize-1 if w > imsize-1 else w
                        h = imsize-1 if h > imsize-1 else h
                        if x <= -1:
                            break
                        data_img[:10, :, y, x:x + w] = 1
                        data_img[:10, :, y:y + h, x] = 1
                        data_img[:10, :, y+h, x:x + w] = 1
                        data_img[:10, :, y:y + h, x + w] = 1

                # get caption
                cap = captions[0].data.cpu().numpy()
                sentence = ""
                for j in range(len(cap)):
                    if cap[j] == 0:
                        break
                    word = self.ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
                    sentence += word + " "
                sentence = sentence[:-1]
                vutils.save_image(data_img, '{}/{}_{}.png'.format(save_dir, sentence, step), normalize=True, nrow=10)
            print("Saved {} files to {}".format(step, save_dir))
Example #30
    def train(self, model):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()  # load encoders
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches

        if cfg.TRAIN.CLIP_SENTENCODER:
            print("CLIP Sentence Encoder: True")

        if cfg.TRAIN.CLIP_LOSS:
            print("CLIP Loss: True")

        if cfg.TRAIN.EXTRA_LOSS:
            print("Extra DAMSM Loss in G: True")
            print("DAMSM Weight: ", cfg.TRAIN.WEIGHT_DAMSM_LOSS)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)

                # imgs, captions, cap_lens, class_ids, keys = prepare_data(data) #new sents:, sents
                # new: return raw texts
                imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                    data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len (e.g. 12 x 256 x 18)
                # sent_emb: batch_size x nef (e.g. 12 x 256)
                words_embs_damsm, sent_emb_damsm = text_encoder(captions, cap_lens, hidden)
                # detach: the pretrained DAMSM text encoder stays frozen during GAN training
                words_embs_damsm = words_embs_damsm.detach()
                sent_emb_damsm = sent_emb_damsm.detach()
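                # mask out padding positions (token id 0), trimmed below to the
                # encoder's actual sequence length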
                mask = (captions == 0)
                num_words = words_embs_damsm.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # new: optionally encode one sentence per example with CLIP
                if cfg.TRAIN.CLIP_SENTENCODER or cfg.TRAIN.CLIP_LOSS:
                    sents = []
                    # randomly select one sentence for each training example
                    for idx in range(len(texts)):
                        # new: split on '\n' rather than '.'
                        sents_per_image = texts[idx].split('\n')
                        if len(sents_per_image) > 1:
                            # randint's upper bound is exclusive, so the last
                            # (possibly empty) split is never chosen
                            sent_ix = np.random.randint(0, len(sents_per_image) - 1)
                        else:
                            sent_ix = 0
                        sents.append(sents_per_image[sent_ix])

                    sent = clip.tokenize(sents)
                    # `model` is a pretrained CLIP (e.g. ViT-B/32 via
                    # torch.jit.load("model.pt")) passed into train(); it is frozen here
                    sent_input = sent.cuda()

                    with torch.no_grad():
                        sent_emb_clip = model.encode_text(sent_input).float()
                        if cfg.TRAIN.CLIP_SENTENCODER:
                            sent_emb = sent_emb_clip
                        else:
                            sent_emb = sent_emb_damsm
                else:
                    sent_emb_clip = 0
                    sent_emb = sent_emb_damsm

                words_embs = words_embs_damsm
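                # word-level embeddings always come from DAMSM; CLIP only
                # replaces the sentence embedding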

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
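                # netG returns multi-scale fake images plus the conditioning-
                # augmentation statistics (mu, logvar) used by the KL term below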
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
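                # one discriminator per image scale, each updated independently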
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()

                # new: also pass the CLIP model, sent_emb_damsm and sent_emb_clip
                # so generator_loss can add a CLIP-based term when CLIP_LOSS is set
                errG_total, G_logs = generator_loss(
                    netsD, image_encoder, fake_imgs, real_labels, words_embs,
                    sent_emb, match_labels, cap_lens, class_ids, model,
                    sent_emb_damsm, sent_emb_clip)

                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # keep an exponential moving average of G's weights (decay 0.999)
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
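                    # temporarily swap in the EMA weights for visualization,
                    # then restore the live training weights afterwards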
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch % 10 == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
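
# A minimal sketch of the EMA helpers this trainer assumes; the real
# copy_G_params/load_params live elsewhere in the repo, so these hypothetical
# versions exist only to make the mul_/add_ averaging above concrete.
def copy_G_params_sketch(netG):
    # snapshot a detached copy of every generator parameter
    return [p.data.clone() for p in netG.parameters()]


def load_params_sketch(netG, new_params):
    # write a snapshot back into the live network, parameter by parameter
    for p, new_p in zip(netG.parameters(), new_params):
        p.data.copy_(new_p)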