示例#1
0
    def test(self,
             session,
             step,
             summary_writer=None,
             print_rate=1,
             sample_dir=None,
             meta=None):
        """Evaluate the model on one batch and produce a forecast.

        Runs the loss / RMSE / summary ops and ``self.sample`` on the feed dict
        returned by ``self.get_feed_dict_and_orig``, denormalizes both the
        original sequence and the forecast, and every ``print_rate`` steps
        prints the metrics and saves both as images.

        Args:
            session: Session used to run the graph ops.
            step: Global step; used for the summary, the print cadence and the
                saved image names.
            summary_writer: Optional summary writer; when None the summary is
                not written.
            print_rate: Print and save images every ``print_rate`` steps;
                values <= 0 disable printing and saving.
            sample_dir: Directory passed to ``save_image``.
            meta: Forwarded to ``denormalize``.

        Returns:
            ``(rmse_all, costs, diff)``: ``rmse_all`` is
            ``[rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo]``, ``costs`` is
            ``[g_loss_pure, g_reg, d_loss_val, d_pen]`` and ``diff`` holds, for
            each (original, forecast) pair, ``(orig - gen)[:, 1, :, :, :]``.
        """
        feed_dict, original_sequence = self.get_feed_dict_and_orig(session)

        (g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh,
         rmse_sp, rmse_geo, summary) = session.run(
             [
                 self.g_cost_pure, self.gen_reg, self.d_cost, self.d_penalty,
                 self.rmse_temp, self.rmse_cc, self.rmse_sh, self.rmse_sp,
                 self.rmse_geo, self.summary_op
             ],
             feed_dict=feed_dict)

        # The writer is optional (defaults to None); only log when given one.
        if summary_writer is not None:
            summary_writer.add_summary(summary, step)

        # View the original as a single-sequence batch:
        # (1, frame_size, crop_size, crop_size, channels).
        original_sequence = original_sequence.reshape([
            1, self.frame_size, self.crop_size, self.crop_size, self.channels
        ])
        # Generate the forecast from the same feed dict.
        forecast = session.run(self.sample, feed_dict=feed_dict)

        denorm_original_sequence = denormalize(original_sequence, self.wvars,
                                               self.crop_size, self.frame_size,
                                               self.channels, meta)
        denorm_forecast = denormalize(forecast, self.wvars, self.crop_size,
                                      self.frame_size, self.channels, meta)

        # Error (original - forecast) per pair, keeping index 1 of axis 1.
        # NOTE(review): assumes each paired element is 5-D — confirm against
        # the shape denormalize returns.
        diff = [(orig - gen)[:, 1, :, :, :]
                for orig, gen in zip(denorm_original_sequence, denorm_forecast)]

        # Guard print_rate <= 0, which would otherwise divide by zero.
        if print_rate > 0 and step % print_rate == 0:
            print(
                "Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)"
                % (step, g_loss_pure, g_reg, d_loss_val, d_pen))
            print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" %
                  (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))

            print('saving original')
            save_image(denorm_original_sequence, sample_dir,
                       'init_%d_image' % step)
            print('saving forecast / fakes')
            save_image(denorm_forecast, sample_dir, 'gen_%d_future' % step)

        rmse_all = [rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo]
        costs = [g_loss_pure, g_reg, d_loss_val, d_pen]

        return rmse_all, costs, diff
示例#2
0
    def train(self,
              session,
              step,
              summary_writer=None,
              log_summary=False,
              sample_dir=None,
              generate_sample=False,
              meta=None):
        """Run one training step: critic updates followed by generator updates.

        Runs ``self.d_adam`` ``self.critic_iterations`` times, each with a
        freshly fetched feed dict, then ``self.g_adam_gan`` and
        ``self.g_adam_first`` on one shared feed dict.

        Args:
            session: Session used to run the graph ops.
            step: Global step; used for the summary, printing and image names.
            summary_writer: Optional summary writer; when None the summary is
                not written even if ``log_summary`` is set.
            log_summary: If true, fetch and print losses, RMSEs and the fake
                video's min/max, plus the elapsed time of this call.
            sample_dir: Directory passed to ``save_image``.
            generate_sample: If true, forecast from the first frame of
                ``self.videos`` and save original and forecast images.
            meta: Forwarded to ``denormalize``.
        """
        if log_summary:
            start_time = time.time()

        # Each critic update gets its own feed dict.
        for _ in range(self.critic_iterations):
            session.run(self.d_adam, feed_dict=self.get_feed_dict(session))

        # Both generator updates (and the summary below) share one feed dict.
        feed_dict = self.get_feed_dict(session)
        session.run(self.g_adam_gan, feed_dict=feed_dict)
        session.run(self.g_adam_first, feed_dict=feed_dict)

        if log_summary:
            (g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh,
             rmse_sp, rmse_geo, fake_min, fake_max, summary) = session.run(
                 [self.g_cost_pure, self.gen_reg, self.d_cost, self.d_penalty,
                  self.rmse_temp, self.rmse_cc, self.rmse_sh, self.rmse_sp,
                  self.rmse_geo, self.fake_min, self.fake_max, self.summary_op],
                 feed_dict=feed_dict)
            # The writer is optional (defaults to None); only log when given one.
            if summary_writer is not None:
                summary_writer.add_summary(summary, step)
            print("Time: %g/itr, Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)" % (
                time.time() - start_time, step, g_loss_pure, g_reg, d_loss_val, d_pen))
            print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" % (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))
            print("Fake_vid min: %g, max: %g" % (fake_min, fake_max))

        if generate_sample:
            original_sequence = session.run(self.videos)
            # (batch_size, frame_size, crop_size, crop_size, channels)
            original_sequence = original_sequence.reshape([self.batch_size, self.frame_size, self.crop_size, self.crop_size, self.channels])
            print(original_sequence.shape)
            # Frame 0 of each sequence is the starting state for the forecast.
            images = original_sequence[:, 0, :, :, :]
            forecast = session.run(self.sample, feed_dict={self.input_images: images})

            original_sequence = denormalize(original_sequence, self.wvars, self.crop_size, self.frame_size, self.channels, meta)
            print('saving original')
            save_image(original_sequence, sample_dir, 'init_%d_image' % step)

            forecast = denormalize(forecast, self.wvars, self.crop_size, self.frame_size, self.channels, meta)
            print('saving forecast / fakes')
            save_image(forecast, sample_dir, 'gen_%d_future' % step)
示例#3
0
    def train(self):
        """Train ``self.net`` for 100 passes over ``self.data_loader``.

        For every batch: move it to ``self.device``, reset gradients, run the
        generator and discriminator forward passes, back-propagate the
        discriminator loss and then the generator loss, and step the
        discriminator and generator optimizers. Every ``self.out_inv`` batches
        the two losses are appended to the 'training loss' line plot and two
        images are shown in the 'img' window through ``self.vis`` (presumably
        a visdom client — confirm).
        """
        device = self.device
        self.net.to(device)
        # Plot options for the loss-curve window.
        opts = {
            "title": 'train',
            "xlabel": 'times',
            "ylabel": 'loss',
            "legend": ['dloss', 'gloss']
        }
        # Batch counter across all passes; drives the plotting cadence.
        count = 0
        # NOTE(review): the number of passes is hard-coded and `step` is unused.
        for step in range(100):
            for img in self.data_loader:
                img = img.to(device)
                count += 1
                # assumes reset_grad() zeros the gradients of G and D — confirm
                self.net.reset_grad()

                self.net.g_forward(img)
                self.net.d_forward()

                # retain_graph=True keeps the autograd graph after this
                # backward, presumably because gloss reuses part of it.
                dloss = self.net.calc_d_loss()
                dloss.backward(retain_graph=True)

                # NOTE(review): gloss.backward() runs before d_opt.step(); if
                # gloss depends on D's parameters, those gradients are added to
                # D's update as well — confirm this is intended.
                gloss = self.net.calc_g_loss()
                gloss.backward()
                self.net.d_opt.step()

                # D is switched to frozen(False) only around the generator
                # step; the exact meaning of frozen() is defined elsewhere.
                self.net.D.frozen(False)
                self.net.g_opt.step()
                self.net.D.frozen(True)

                if not count % self.out_inv:
                    # NOTE(review): `count` already counts batches, so the x
                    # value out_inv * count is scaled by out_inv — confirm.
                    self.vis.line(
                        X=[self.out_inv * count],
                        Y=[[dloss.detach().cpu(),
                            gloss.detach().cpu()]],
                        update='append',
                        opts=opts,
                        win='training loss')
                    # First element of raw_img and of a randomly chosen entry
                    # of add_watermark_img_trans, denormalized for display.
                    self.vis.images([
                        denormalize(self.net.raw_img[0]),
                        denormalize(
                            random.choice(self.net.add_watermark_img_trans)[0])
                    ],
                                    win='img')
def main(plot_dir, epoch):
    """Render an MP4 animation of glimpse locations for one epoch.

    Loads ``g_<epoch>.p`` (glimpses) and ``l_<epoch>.p`` (locations) from
    ``plot_dir``. Each glimpse is drawn in its own subplot (grayscale), and a
    bounding box is animated over the recorded locations, one frame per entry
    of ``locations``, at 500 ms per frame. The result is written to
    ``plot_dir + 'epoch_<epoch>.mp4'``.

    Args:
        plot_dir: Path prefix that is concatenated directly with file names,
            so it should end with a separator. The box size is parsed from the
            first character of its third ``'_'``-separated field.
        epoch: Epoch number selecting the input pickles and naming the output.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(plot_dir + "g_{}.p".format(epoch), "rb") as f:
        glimpses = pickle.load(f)
    with open(plot_dir + "l_{}.p".format(epoch), "rb") as f:
        locations = pickle.load(f)

    glimpses = np.concatenate(glimpses)

    # grab useful params
    size = int(plot_dir.split('_')[2][0])
    num_anims = len(locations)
    num_cols = glimpses.shape[0]
    img_shape = glimpses.shape[1]

    # denormalize coordinates
    coords = [denormalize(img_shape, loc) for loc in locations]

    # squeeze=False keeps `axs` an array even when num_cols == 1; otherwise a
    # bare Axes is returned and `axs.flat` would raise AttributeError.
    fig, axs = plt.subplots(nrows=1, ncols=num_cols, squeeze=False)

    # plot base image
    for j, ax in enumerate(axs.flat):
        ax.imshow(glimpses[j], cmap="Greys_r")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    def updateData(i):
        """Replace every subplot's bounding box with the one for frame i."""
        color = 'r'
        co = coords[i]
        for j, ax in enumerate(axs.flat):
            # Iterate over a snapshot: removing while iterating skips items.
            for p in list(ax.patches):
                p.remove()
            c = co[j]
            rect = bounding_box(c[0], c[1], size, color)
            ax.add_patch(rect)

    # animate
    anim = animation.FuncAnimation(fig,
                                   updateData,
                                   frames=num_anims,
                                   interval=500,
                                   repeat=True)

    # save as mp4
    name = plot_dir + 'epoch_{}.mp4'.format(epoch)
    anim.save(name, extra_args=['-vcodec', 'h264', '-pix_fmt', 'yuv420p'])
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
示例#5
0
    def get_images(self, net_student=None, targets=None):
        """Synthesize a batch of images by optimizing noise against the teacher.

        Starts from Gaussian noise of shape ``(self.bs, 3, R, R)``, where
        ``R = self.image_resolution``, and optimizes it with Adam. The loss is
        ``self.main_loss_multiplier * criterion(teacher(x), targets)`` plus
        weighted image-prior terms (``get_image_prior_losses``), the
        ``r_feature`` values of ``self.loss_r_feature_layers``, an L2 penalty
        on the inputs and, when ``self.adi_scale != 0``, a teacher/student
        divergence term. Each step applies random jitter (``torch.roll``) and
        an optional flip. Setting 0 first runs a half-resolution phase.

        Args:
            net_student: Student network; only called when ``self.adi_scale``
                is non-zero.
            targets: Target labels. When None they are drawn uniformly from
                ``[0, 999]``, or, if ``self.random_label`` is false, tiled from
                a preselected class list to exactly ``self.bs`` labels.

        Side effects:
            Prints progress and writes snapshots to
            ``{self.prefix}/best_images/`` every ``self.save_every`` iterations
            (a value <= 0 disables this). If ``self.store_best_images`` is set,
            passes the lowest-loss inputs seen to ``self.save_images``.

        Raises:
            ValueError: if ``self.setting_id`` is not 0, 1 or 2.
        """
        print("get_images call")

        # setting_id -> (Adam betas, clip pixel values after each step):
        #   0: multi-resolution, 2k iterations at low resolution then 1k at
        #      normal resolution (ResNet50v1.5 works best, ResNet50 is ok)
        #   1: 2k iterations at normal resolution (ResNet50v1.5; ResNet50
        #      works as well)
        #   2: 20k iterations at normal resolution, closest to the paper
        #      experiments for ResNet50
        setting_params = {
            0: ((0.5, 0.9), True),
            1: ((0.5, 0.9), True),
            2: ((0.9, 0.999), False),
        }
        if self.setting_id not in setting_params:
            # Fail fast instead of reaching an unbound `optimizer` later.
            raise ValueError("Unsupported setting_id: {}".format(
                self.setting_id))
        betas, do_clip = setting_params[self.setting_id]

        net_teacher = self.net_teacher
        use_fp16 = self.use_fp16
        save_every = self.save_every

        kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()
        local_rank = torch.cuda.current_device()
        criterion = self.criterion

        # Lowest total loss seen so far and the inputs recorded with it.
        best_cost = 1e4
        best_inputs = None

        # setup target labels
        if targets is None:
            # Only works for classification; other tasks must provide targets.
            targets = torch.LongTensor(
                [random.randint(0, 999) for _ in range(self.bs)]).to('cuda')
            if not self.random_label:
                # preselected classes, good for ResNet50v1.5
                preselected = [
                    1, 933, 946, 980, 25, 63, 92, 94, 107, 985, 151, 154, 207,
                    250, 270, 277, 283, 292, 294, 309, 311, 325, 340, 360, 386,
                    402, 403, 409, 530, 440, 468, 417, 590, 670, 817, 762, 920,
                    949, 963, 967, 574, 487
                ]
                # Tile and trim to exactly self.bs labels, so the label count
                # matches the batch even when self.bs is not a multiple of
                # len(preselected).
                repeats = -(-self.bs // len(preselected))
                targets = torch.LongTensor(
                    (preselected * repeats)[:self.bs]).to('cuda')

        img_original = self.image_resolution

        data_type = torch.half if use_fp16 else torch.float
        inputs = torch.randn((self.bs, 3, img_original, img_original),
                             requires_grad=True,
                             device='cuda',
                             dtype=data_type)

        pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2)

        # Every setting except 0 skips the low-resolution phase.
        skipfirst = self.setting_id != 0

        iteration = 0
        for lr_it, lower_res in enumerate([2, 1]):
            if lr_it == 0 and skipfirst:
                continue

            if lr_it == 0:
                iterations_per_layer = 2000
            elif self.setting_id == 2:
                iterations_per_layer = 20000
            else:
                iterations_per_layer = 1000 if not skipfirst else 2000

            lim_0, lim_1 = self.jitter // lower_res, self.jitter // lower_res

            # A fresh optimizer for each resolution phase.
            optimizer = optim.Adam([inputs],
                                   lr=self.lr,
                                   betas=betas,
                                   eps=1e-8)

            if use_fp16:
                # mixed precision via amp.initialize, dynamic loss scaling
                _, optimizer = amp.initialize([],
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic")

            lr_scheduler = lr_cosine_policy(self.lr, 100, iterations_per_layer)

            for iteration_loc in range(iterations_per_layer):
                iteration += 1
                # Logging/snapshot cadence. Check save_every > 0 first so that
                # 0 disables logging instead of dividing by zero.
                log_now = save_every > 0 and iteration % save_every == 0

                # learning rate scheduling
                lr_scheduler(optimizer, iteration_loc, iteration_loc)

                # downsample during the low-resolution phase
                if lower_res != 1:
                    inputs_jit = pooling_function(inputs)
                else:
                    inputs_jit = inputs

                # apply random jitter offsets
                off1 = random.randint(-lim_0, lim_0)
                off2 = random.randint(-lim_1, lim_1)
                inputs_jit = torch.roll(inputs_jit,
                                        shifts=(off1, off2),
                                        dims=(2, 3))

                # Random flip along the last dim. random.random() is drawn
                # even when flipping is disabled.
                flip = random.random() > 0.5
                if flip and self.do_flip:
                    inputs_jit = torch.flip(inputs_jit, dims=(3, ))

                # forward pass
                optimizer.zero_grad()
                net_teacher.zero_grad()

                outputs = net_teacher(inputs_jit)
                outputs = self.network_output_function(outputs)

                # R_cross classification loss
                loss_main = criterion(outputs, targets)

                # R_prior losses
                loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)

                # R_feature loss (values produced during the teacher forward)
                loss_r_feature = sum(
                    [mod.r_feature for mod in self.loss_r_feature_layers])

                # R_ADI: teacher/student disagreement term
                loss_verifier_cig = torch.zeros(1)
                if self.adi_scale != 0.0:
                    outputs_student = net_student(inputs_jit)
                    if self.detach_student:
                        outputs_student = outputs_student.detach()

                    T = 3.0
                    P = nn.functional.softmax(outputs_student / T, dim=1)
                    Q = nn.functional.softmax(outputs / T, dim=1)
                    M = 0.5 * (P + Q)

                    P = torch.clamp(P, 0.01, 0.99)
                    Q = torch.clamp(Q, 0.01, 0.99)
                    M = torch.clamp(M, 0.01, 0.99)
                    # NOTE(review): KLDivLoss(log P, M) computes KL(M || P), so
                    # this is 0.5*KL(M||P) + 0.5*KL(M||Q), the reverse
                    # direction of the textbook Jensen-Shannon divergence —
                    # confirm this is intended.
                    loss_verifier_cig = 0.5 * kl_loss(
                        torch.log(P), M) + 0.5 * kl_loss(torch.log(Q), M)
                    # The clamped divergence is 0 when the two distributions
                    # agree. With a positive adi_scale, minimizing 1 - it pushes
                    # student and teacher apart.
                    loss_verifier_cig = 1.0 - torch.clamp(
                        loss_verifier_cig, 0.0, 1.0)

                    if local_rank == 0 and log_now:
                        print('loss_verifier_cig', loss_verifier_cig.item())

                # l2 loss on images
                loss_l2 = torch.norm(inputs_jit.view(self.bs, -1),
                                     dim=1).mean()

                # combining losses
                loss_aux = self.var_scale_l2 * loss_var_l2 + \
                           self.var_scale_l1 * loss_var_l1 + \
                           self.bn_reg_scale * loss_r_feature + \
                           self.l2_scale * loss_l2

                if self.adi_scale != 0.0:
                    loss_aux += self.adi_scale * loss_verifier_cig

                loss = self.main_loss_multiplier * loss_main + loss_aux

                if local_rank == 0 and log_now:
                    print("------------iteration {}----------".format(
                        iteration))
                    print("total loss", loss.item())
                    print("loss_r_feature", loss_r_feature.item())
                    print("main criterion", loss_main.item())

                    if self.hook_for_display is not None:
                        self.hook_for_display(inputs, targets)

                # do image update
                if use_fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()

                # clip color outliers
                if do_clip:
                    inputs.data = clip(inputs.data, use_fp16=use_fp16)

                # Keep a copy of the (post-step) inputs whenever the loss
                # improves on the best seen so far.
                loss_value = loss.item()
                if loss_value < best_cost or iteration == 1:
                    best_cost = loss_value
                    best_inputs = inputs.data.clone()

                if log_now and local_rank == 0:
                    vutils.save_image(
                        inputs,
                        '{}/best_images/output_{:05d}_gpu_{}.png'.format(
                            self.prefix, iteration // save_every,
                            local_rank),
                        normalize=True,
                        scale_each=True,
                        nrow=10)

        if self.store_best_images:
            best_inputs = denormalize(best_inputs)
            self.save_images(best_inputs, targets)

        # to reduce memory consumption by states of the optimizer we deallocate memory
        optimizer.state = collections.defaultdict(dict)
示例#6
0
def over_sampling(X, T, Y, params_over):
    """Oversample (X, T, Y) with GAN-generated rows to rebalance the classes.

    Classes are the groups counted by ``num_class(df, 'Y', 'T')``. The target
    for the largest class is its current count times
    ``params_over['major_multiple']``. Every other class targets that value
    times ``params_over['minor_ratio']``. Targets are rounded, and the shortfall
    per class is floored at 0. A GAN is trained on the normalized data, then for
    up to 10 rounds it generates rows, which are denormalized, rounded, split
    by class, and appended until each class's shortfall is covered.

    Args:
        X: Feature columns.
        T: Column named ``'T'``.
        Y: Column named ``'Y'``.
        params_over: Dict with keys ``gen_lr``, ``dis_lr``, ``batch_size``,
            ``epochs``, ``noise_size``, ``major_multiple``, ``minor_ratio`` and
            ``loss_type``.

    Returns:
        ``(X, T, Y)`` built from the original rows plus the selected generated
        rows, in shuffled order.

    Side effects:
        Stores the trained discriminator in the module-global
        ``stored_discriminator``.
    """
    gen_lr = params_over['gen_lr']
    dis_lr = params_over['dis_lr']
    batch_size = params_over['batch_size']
    epochs = params_over['epochs']
    latent_dim = params_over['noise_size']
    major_multiple = params_over['major_multiple']
    minor_ratio = params_over['minor_ratio']
    loss_type = params_over['loss_type']
    seed = 1234
    max_loop = 10
    fake_multiple = 5

    init_tf(seed)

    train_df = pd.concat([X, T, Y], axis=1).copy()
    # The output starts from an un-normalized copy of the real data.
    out_df = train_df.copy()

    n_samples = pd.Series(num_class(train_df, 'Y', 'T'))
    print('Initial samples:', n_samples.tolist())

    # Target count per class: majority scaled by major_multiple, every other
    # class minor_ratio times that.
    num_major = n_samples.max() * major_multiple
    idx_major = n_samples.argmax()
    num_minor = num_major * minor_ratio

    n_rest_samples = [num_minor] * len(n_samples)
    n_rest_samples[idx_major] = num_major
    n_rest_samples = pd.Series(n_rest_samples).round().astype('int32')
    # Rows still needed per class, never negative.
    n_rest_samples -= n_samples
    n_rest_samples[n_rest_samples < 0] = 0
    # Each round generates fake_multiple times the initial total shortfall.
    num_fake_data = n_rest_samples.sum() * fake_multiple
    print('Initial rest samples:', n_rest_samples.tolist())

    train_df, normalize_vars = normalize(train_df)
    data_dim = train_df.shape[1]

    generator, discriminator, combined = \
        build_gan_network(gen_lr, dis_lr, data_dim, latent_dim, loss_type)
    train(train_df, epochs, batch_size, latent_dim, generator, discriminator,
          combined)

    global stored_discriminator
    stored_discriminator = discriminator

    # Collect pieces and concatenate once at the end.
    selected = [out_df]
    for _ in range(max_loop):
        if n_rest_samples.sum() == 0:
            break
        noise = np.random.normal(0, 1, (num_fake_data, latent_dim))
        gen_data = generator.predict(noise)
        gen_df = pd.DataFrame(gen_data, columns=train_df.columns)
        gen_df = denormalize(gen_df, normalize_vars)
        # NOTE(review): rounds every column, not only T and Y — confirm the
        # features are meant to be integer-valued.
        gen_df = gen_df.round()

        tr, tn, cr, cn = num_class(gen_df, 'Y', 'T')
        print('Generated data (tr, tn, cr, cn):', tr, tn, cr, cn)

        # Take rows from each class's own split so the selected rows actually
        # belong to that class. Assumes split_class orders classes the same
        # way as num_class — confirm.
        for idx, class_df in enumerate(split_class(gen_df, 'Y', 'T')):
            n_sel_samples = min(class_df.shape[0], n_rest_samples[idx])
            if n_sel_samples > 0:
                n_rest_samples[idx] -= n_sel_samples
                selected.append(class_df.iloc[:n_sel_samples])

        print('Rest samples:', n_rest_samples.tolist())

    out_df = pd.concat(selected).reset_index(drop=True).sample(frac=1)
    X = out_df.drop(['T', 'Y'], axis=1)
    T = out_df['T']
    Y = out_df['Y']

    return X, T, Y