コード例 #1
0
ファイル: TimeGAN.py プロジェクト: popescuaaa/gan-playground
    def train_system(self):
        """Train for ``self.num_epochs`` epochs over the batches of ``self.dl``.

        Each batch yields ``(stock, dt)``. Both tensors get a trailing size-1
        axis, are cast to float and moved to ``self.device``. Latent noise is
        drawn from ``self.dist_latent``. Then ``self.train_generator`` and
        ``self.train_discriminator`` are each called once per batch, with the
        same noise, ``dt`` and the scalar mean of the real batch.

        On the first batch of every epoch the discriminator/generator losses
        are printed and logged to wandb. The log also includes the MAE between
        the first real and first generated series, plus plots of both.
        """
        for epoch in range(self.num_epochs):
            for batch_idx, (stock, dt) in enumerate(self.dl):
                # Append a trailing feature axis of size 1, cast, move to device.
                stock = stock.unsqueeze(-1).float().to(self.device)
                dt = dt.unsqueeze(-1).float().to(self.device)

                mean = stock.mean()

                # Size the noise by the actual batch, not self.batch_size.
                # The final batch from the loader may be smaller (unless the
                # loader drops it), and the noise must match the batch
                # dimension of stock/dt.
                noise_stock = self.dist_latent.sample(
                    sample_shape=(stock.size(0), self.seq_len,
                                  self.g_dim_latent))
                noise_stock = noise_stock.to(self.device)

                loss_g, _ = self.train_generator(noise_stock, dt, mean)
                loss_d, fake = self.train_discriminator(
                    noise_stock, stock, dt, mean)

                if batch_idx == 0:
                    # Adjacent f-strings are concatenated, so no source
                    # indentation leaks into the printed message.
                    print(f"Epoch [{epoch}/{self.num_epochs}] "
                          f"Batch {batch_idx}/{len(self.dl)} "
                          f"Loss D: {loss_d:.4f}, loss G: {loss_g:.4f}")

                    # First sample of the batch, flattened to 1-D numpy arrays.
                    real_sample = stock[0].view(-1).cpu().numpy()
                    fake_sample = fake[0].view(-1).detach().cpu().numpy()

                    wandb.log({
                        'epoch': epoch,
                        'd loss': loss_d,
                        'g loss': loss_g,
                        'mae': mean_absolute_error(real_sample, fake_sample),
                        'Conditional on deltas fake sample': plot_time_series(
                            fake_sample,
                            '[Conditional (on deltas)] Fake sample'),
                        'Real sample': plot_time_series(
                            real_sample,
                            '[Corresponding] Real sample'),
                    })
コード例 #2
0
    def train_system(self):
        """Train for ``self.num_epochs`` epochs over the batches of ``self.dl``.

        Each batch yields ``(real_data, dt)``. Both tensors get a trailing
        size-1 axis, are cast to float and moved to ``self.device``. Latent
        noise is drawn from ``self.dist_latent``. Then
        ``self.train_generator`` runs ``self.g_iter`` times and
        ``self.train_discriminator`` runs ``self.d_iter`` times, all on that
        same noise batch.

        On the first batch of every epoch the losses are printed. They are
        also logged to wandb together with:

        * the MAE between the first real and first generated series;
        * plots of both series;
        * a figure from ``visualization_metrics.visualize``, built from
          ``self.generate_distributions()``.

        NOTE: ``loss_g``/``loss_d`` (and ``fake``) are only bound when
        ``self.g_iter``/``self.d_iter`` are >= 1.
        """
        for epoch in range(self.num_epochs):
            for batch_idx, (real_data, dt) in enumerate(self.dl):
                # Append a trailing feature axis of size 1, cast, move to device.
                real_data = real_data.unsqueeze(-1).float().to(self.device)
                dt = dt.unsqueeze(-1).float().to(self.device)

                # Size the noise by the actual batch, not self.batch_size.
                # The final batch from the loader may be smaller (unless the
                # loader drops it), and the noise must match the batch
                # dimension of real_data/dt.
                noise = self.dist_latent.sample(
                    sample_shape=(real_data.size(0), self.seq_len,
                                  self.g_dim_latent))
                noise = noise.to(self.device)

                for _g_step in range(self.g_iter):
                    loss_g, _ = self.train_generator(noise, real_data, dt)

                for _d_step in range(self.d_iter):
                    loss_d, fake = self.train_discriminator(noise, real_data, dt)

                if batch_idx == 0:
                    # Convert the loss tensors to Python floats once.
                    d_loss = loss_d.detach().cpu().item()
                    g_loss = loss_g.detach().cpu().item()

                    # Adjacent f-strings are concatenated, so no source
                    # indentation leaks into the printed message.
                    print(f"Epoch [{epoch}/{self.num_epochs}] "
                          f"Batch {batch_idx}/{len(self.dl)} "
                          f"Loss D: {d_loss:.4f}, loss G: {g_loss:.4f}")

                    # Visualize the whole distribution
                    rd, gd = self.generate_distributions()
                    fig = visualization_metrics.visualize(gd, rd)

                    # First sample of the batch, flattened to 1-D numpy arrays.
                    real_sample = real_data[0].view(-1).cpu().numpy()
                    fake_sample = fake[0].view(-1).detach().cpu().numpy()

                    wandb.log({
                        'epoch': epoch,
                        'd loss': d_loss,
                        'g loss': g_loss,
                        'mae': mean_absolute_error(real_sample, fake_sample),
                        'Conditional on deltas fake sample': plot_time_series(
                            fake_sample,
                            '[Conditional (on deltas)] Fake sample'),
                        'Real sample': plot_time_series(
                            real_sample,
                            '[Corresponding] Real sample'),
                        'Distribution': fig,
                    })
コード例 #3
0
def analyse_owd(test_id='', out_dir='', replot_only='0', source_filter='',
                min_values='3', omit_const='0', ymin='0', ymax='0',
                lnames='', stime='0.0', etime='0.0', out_name='', pdf_dir='',
                ts_correct='1', plot_params='', plot_script='', burst_sep='0.0',
                sburst='1', eburst='0', seek_window='', anchor_map='', owd_midpoint='0'):
    """Plot OWD of flows.

    Steps:

    1. Extract one-way delay data with ``_extract_owd_pktloss`` (called with
       ``log_loss='0'``).
    2. Drop series rejected by ``filter_min_values`` (using ``min_values``).
    3. Plot the remaining series with the y label 'OWD (ms)'. The output
       prefix is the resolved output name followed by '_owd'.

    Parameters default to strings. ``ymin``, ``ymax`` and ``burst_sep`` are
    converted with ``float()`` and ``sburst`` with ``int()`` before plotting.

    ts_correct: accepted only for signature parity with the other
        analyse_* tasks. An explicit '0' is rejected via ``abort()``, since
        OWD calculations require timestamp correction.
    burst_sep: 0.0 produces a single plot via ``plot_time_series``. Any
        other value produces a per-burst plot via ``plot_incast_ACK_series``.
    eburst, seek_window, anchor_map, owd_midpoint: forwarded to
        ``_extract_owd_pktloss`` only.
    """
    if ts_correct == '0':
        abort("Warning: Cannot do OWD calculations with ts_correct=0")

    test_id_arr, out_files, out_groups = _extract_owd_pktloss(
        test_id, out_dir, replot_only, source_filter, ts_correct,
        burst_sep, sburst, eburst, seek_window,
        log_loss='0', anchor_map=anchor_map, owd_midpoint=owd_midpoint)

    out_files, out_groups = filter_min_values(out_files, out_groups, min_values)
    out_name = get_out_name(test_id_arr, out_name)

    burst_sep = float(burst_sep)

    # Leading positional arguments common to both plotting helpers.
    # NOTE(review): 1000.0 presumably scales seconds to the 'ms' in the
    # y label; confirm against the helpers' signatures.
    positional = (out_name, out_files, 'OWD (ms)', 2, 1000.0, 'pdf',
                  out_name + '_owd')
    # Keyword arguments common to both plotting helpers.
    shared = dict(pdf_dir=pdf_dir, omit_const=omit_const,
                  ymin=float(ymin), ymax=float(ymax), lnames=lnames,
                  stime=stime, etime=etime, groups=out_groups,
                  plot_params=plot_params, plot_script=plot_script,
                  source_filter=source_filter)

    if burst_sep != 0.0:
        # Each trial has multiple files containing data from separate
        # bursts detected within the trial.
        plot_incast_ACK_series(*positional, aggr='', burst_sep=burst_sep,
                               sburst=int(sburst), **shared)
    else:
        plot_time_series(*positional, **shared)

    puts('\n[MAIN] COMPLETED plotting OWDs %s \n' % out_name)
コード例 #4
0
def analyse_pktloss(test_id='', out_dir='', replot_only='0', source_filter='',
                    min_values='3', omit_const='0', ymin='0', ymax='0',
                    lnames='', stime='0.0', etime='0.0', out_name='', pdf_dir='',
                    ts_correct='1', plot_params='', plot_script='', burst_sep='0.0',
                    sburst='1', eburst='0', seek_window='', log_loss='2'):
    """Plot per-flow packet loss events vs time (or cumulative over time).

    Steps:

    1. Extract loss data with ``_extract_owd_pktloss``.
    2. Drop series rejected by ``filter_min_values`` (using ``min_values``).
    3. Plot the remaining series with the y label 'Lost packets'. The output
       prefix is the resolved output name followed by '_loss2'.

    Parameters default to strings. ``ymin``, ``ymax`` and ``burst_sep`` are
    converted with ``float()`` and ``sburst`` with ``int()`` before plotting.

    log_loss: '1' for packet loss events, '2' for cumulative packet loss.
        Any other value is rejected via ``abort()``.
    burst_sep: 0.0 produces a single plot via ``plot_time_series``. Any
        other value produces a per-burst plot via ``plot_incast_ACK_series``.
    eburst, seek_window, ts_correct: forwarded to ``_extract_owd_pktloss``
        only.
    """
    if log_loss not in ('1', '2'):
        abort("Must set log_loss=1 (pkt loss events) or log_loss=2 (cumulative pkt loss)")

    test_id_arr, out_files, out_groups = _extract_owd_pktloss(
        test_id, out_dir, replot_only, source_filter, ts_correct,
        burst_sep, sburst, eburst, seek_window, log_loss)

    out_files, out_groups = filter_min_values(out_files, out_groups, min_values)
    out_name = get_out_name(test_id_arr, out_name)

    burst_sep = float(burst_sep)

    # Leading positional arguments common to both plotting helpers.
    # NOTE(review): the '_loss2' prefix is used for both log_loss='1' and
    # log_loss='2'; confirm that is intended.
    positional = (out_name, out_files, 'Lost packets', 2, 1, 'pdf',
                  out_name + '_loss2')
    # Keyword arguments common to both plotting helpers.
    shared = dict(pdf_dir=pdf_dir, omit_const=omit_const,
                  ymin=float(ymin), ymax=float(ymax), lnames=lnames,
                  stime=stime, etime=etime, groups=out_groups,
                  plot_params=plot_params, plot_script=plot_script,
                  source_filter=source_filter)

    if burst_sep != 0.0:
        # Each trial has multiple files containing data from separate
        # bursts detected within the trial.
        plot_incast_ACK_series(*positional, aggr='', burst_sep=burst_sep,
                               sburst=int(sburst), **shared)
    else:
        plot_time_series(*positional, **shared)

    puts('\n[MAIN] COMPLETED plotting pktloss %s \n' % out_name)