Example #1
        },
        ep,
    )

    # Pick one random sample out of the batch to visualize
    show_idx = np.random.randint(0, 8)
    fig = plt.figure()
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)
    # Undo the approximate input normalization (std ~0.225, mean ~0.5) for display
    ax1.imshow(img[show_idx].permute(1, 2, 0).detach().cpu().numpy() * 0.225 +
               0.5)
    ax2.imshow(lab[show_idx].detach().cpu().numpy())
    ax3.imshow(mask[show_idx].detach().cpu().numpy())
    ax4.imshow(th_pred[show_idx].detach().cpu().numpy())

    if show_plt:
        plt.show()
    else:
        writer.add_figure("TrainFig", fig, ep, close=True)

    if test_miou.value()[0] >= best_mIU:
        best_mIU = test_miou.value()[0]
        state = {
            'state_one': net_one.state_dict(),
            'EPOCH_start': ep,
            'mIU': best_mIU
        }
        print("Saving model")
        torch.save(state, save_path)
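
A minimal sketch of how the checkpoint saved above could be restored later; it assumes the same `net_one` architecture and the same `save_path`, and the keys mirror the `state` dict written by the loop.

import torch

# Hypothetical restore step for the checkpoint saved by the loop above.
checkpoint = torch.load(save_path, map_location='cpu')
net_one.load_state_dict(checkpoint['state_one'])
start_epoch = checkpoint['EPOCH_start']  # epoch at which the best mIoU was saved
best_mIU = checkpoint['mIU']             # best test mIoU seen so far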
Example #2
class GPUWorker(object):
    """Object that exists on each GPU and handles that GPU's training."""

    def __init__(self, device, pop, batches, top, trunc, mut, model_flags,
                 maxgen):
        """Download parameters and allocate memory for models and data."""
        self.device = torch.device(device)
        self.num_models = pop
        self.num_batches = batches
        self.models = []
        self.train_data = []
        self.test_data = []
        self.val_data = []
        self.mut = mut
        self.max_gen = maxgen
        self.flags = model_flags

        # Set trunc threshold to integer
        if top == 0:
            self.trunc_threshold = int(trunc * pop)
        else:
            self.trunc_threshold = top

        self.elite_eval = torch.zeros(self.trunc_threshold, device=self.device)
        # Timestamp components for the TensorBoard log directory name
        (year, month, day, hour, minute, sec, _, _, _) = time.localtime(
            time.time())

        self.writer = SummaryWriter(
            "results/Iris_m{}_P{}_t{}_{}{}{}_{}{}{}".format(
                mut, pop, top, year, month, day, hour, minute, sec))

        # Model generation: models are created on the CPU, moved to the GPU,
        # and the references are stored on the CPU
        for i in range(pop):
            #self.models.append(Forward(model_flags).cuda(self.device))
            self.models.append(model(model_flags).cuda(self.device))

        # Dataset storage
        train_data_loader, test_data_loader, val_data_loader = datareader.read_data(
            x_range=[i for i in range(0, 4)],
            y_range=[i for i in range(4, 7)],
            geoboundary=[20, 200, 20, 100],
            batch_size=0,
            set_size=batches,
            normalize_input=True,
            data_dir='./',
            test_ratio=0.2)

        for (geometry, spectra) in train_data_loader:
            self.train_data.append(
                (geometry.to(self.device), spectra.to(self.device)))

        for (geometry, spectra) in test_data_loader:
            self.test_data.append(
                (geometry.to(self.device), spectra.to(self.device)))

        for (geometry, spectra) in val_data_loader:
            self.val_data.append(
                (geometry.to(self.device), spectra.to(self.device)))

        # Optionally load best_model.pt and start a population of its mutants (disabled below):
        """
        with torch.no_grad():
            rand_mut = self.collect_random_mutations()
            self.models[0] = torch.load('best_model.pt', map_location=self.device)
            self.models[0].eval()

            m_t = self.models[0]
            for i in range(1, pop):
                m = self.models[i]
                for (mp, m_tp, mut) in zip(m.parameters(), m_t.parameters(), rand_mut):
                    mp.copy_(m_tp).add_(mut[i])
                m.eval()
        """

        # GPU tensors that store fitness values and their sorting for the population.
        # Would ideally live in GPU shared memory.
        self.fit = torch.zeros(pop, device=self.device)
        self.sorted = torch.zeros(pop, device=self.device)
        self.hist_plot = []
        self.lorentz_plot = []

    def run(self, gen):
        """Manage the run on a single GPU."""

        with torch.no_grad():

            # Queue the fitness calculation for every model on the GPU. Doing this
            # all at once, with every model already loaded, might cause slowdowns
            # due to lack of memory; apparently only ~16 kernels can execute at once.
            # Generate array of random indices corresponding to each batch for each model
            rand_batch = np.random.randint(self.num_batches,
                                           size=self.num_models)
            for j in range(self.num_models):
                g, s = self.train_data[rand_batch[j]]
                self.models[j].eval()
                fwd, w0, wp, g_n = self.models[j](g)
                self.fit[j] = Forward.fitness_f(
                    fwd, s)  #- self.custom_loss(w0,wp,g_n,gen)
                #self.BC.append(lorentz_model.bc_func(bc))

            # Wait for every queued kernel to execute
            torch.cuda.synchronize(self.device)

            # Get a sorting based on the fitness array; argsort is GPU-optimized
            # Sorted is a tensor of indices organized from best to worst model
            self.sorted = torch.argsort(self.fit, descending=True)
            self.writer.add_scalar('training loss', -self.fit[self.sorted[0]],
                                   gen)
            g, s = self.val_data[0]
            fwd, w0, wp, g_n = self.models[self.sorted[0]](g)
            self.writer.add_scalar('validation loss',
                                   -Forward.fitness_f(fwd, s), gen)
            #.writer.add_scalar('validation loss', -Forward.fitness_f(fwd, s) + self.custom_loss(w0, wp, g_n, gen),gen)

            # Find the champion by validating against the test set (this compromises
            # the test set for future evaluation)
            g, s = self.test_data[0]

            # Run top training model over evaluation dataset to get elite_val fitness score and index
            fwd, w0, wp, g_n = self.models[self.sorted[0]](g)
            elite_val = Forward.fitness_f(
                fwd, s)  # - self.custom_loss(w0,wp,g_n,gen)
            self.elite_eval[0] = elite_val

            # Run rest of top models over evaluation dataset, storing their scores and saving champion and its index
            for i in range(1, self.trunc_threshold):
                fwd, w0, wp, g_n = self.models[self.sorted[i]](g)
                self.elite_eval[i] = Forward.fitness_f(
                    fwd, s)  #- self.custom_loss(w0,wp,g_n,gen)
                if elite_val < self.elite_eval[i]:
                    # Swap current champion index in self.sorted with index of model that out-performed it in eval
                    # Technically the sorted list is no longer in order, but this does not matter as the top models
                    # are all still top models and the champion is in the champion position
                    elite_val = self.elite_eval[i]
                    former_champ_idx = self.sorted[0]
                    self.sorted[0] = self.sorted[i]
                    self.sorted[i] = former_champ_idx

            # Copy models above the truncation barrier randomly into the bottom models, with mutation
            # Generate array of random indices corresponding to models above trunc barrier, and collect mutation arrays
            rand_model_p = self.collect_random_mutations()
            rand_top = np.random.randint(self.trunc_threshold,
                                         size=self.num_models -
                                         self.trunc_threshold)
            for i in range(self.trunc_threshold, self.num_models):
                # Grab all truncated models
                m = self.models[self.sorted[i]]

                # Grab random top model
                m_t = self.models[self.sorted[rand_top[i -
                                                       self.trunc_threshold]]]

                for (mp, mtp, mut) in zip(m.parameters(), m_t.parameters(),
                                          rand_model_p):
                    # Copy top model parameters chosen into bottom module and mutate with random tensor of same size
                    # New random tensor vals drawn from normal distn center=0 var=1, multiplied by mutation power
                    mp.copy_(mtp).add_(mut[i])

            # Mutate the elite models that are not the champion
            for i in range(1, self.trunc_threshold):
                m = self.models[self.sorted[i]]
                for (mp, mut) in zip(m.parameters(), rand_model_p):
                    # Add random values drawn from a normal distribution
                    # (mean=0, var=1), multiplied by the mutation power
                    mp.add_(mut[i])

            # Synchronize so all models finish mutating and copying before the next generation
            torch.cuda.synchronize(self.device)
            #self.save_plots(gen,plot_arr=[0,18,36,63])

    def collect_random_mutations(self):
        rand_model_p = []

        # One N(0, 1) mutation tensor per parameter tensor (weight and bias of
        # each layer of the small 4-2-3-3 network), scaled by the mutation power.
        for shape in [(2, 4), (2,), (3, 2), (3,), (3, 3), (3,)]:
            rand_model_p.append(
                torch.mul(
                    torch.randn(self.num_models,
                                *shape,
                                requires_grad=False,
                                device=self.device), self.mut))
        '''
        rand_lay_1 = torch.mul(torch.randn(self.num_models, 100, 2, requires_grad=False, device=self.device),
                                   self.mut)
        rand_model_p.append(rand_lay_1)

        rand_bi_1 = torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bi_1)

        rand_lay_2 = torch.mul(torch.randn(self.num_models, 100, 100, requires_grad=False, device=self.device),
                                   self.mut)
        rand_model_p.append(rand_lay_2)

        rand_bi_2 = torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bi_2)

        #Batchnorm mods
        rand_bn_lin_1 =torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_lin_1)

        rand_bn_bias_1 =torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_bias_1)

        rand_bn_lin_2 =torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_lin_2)

        rand_bn_bias_2 =torch.mul(torch.randn(self.num_models, 100, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_bias_2)

        rand_lin_g = torch.mul(torch.randn(self.num_models, 1, 100, requires_grad=False, device=self.device),
                                   self.mut)
        rand_model_p.append(rand_lin_g)

        rand_lin_w0 = torch.mul(torch.randn(self.num_models, 1, 100, requires_grad=False, device=self.device),
                                    self.mut)
        rand_model_p.append(rand_lin_w0)

        rand_lin_wp = torch.mul(torch.randn(self.num_models, 1, 100, requires_grad=False, device=self.device),
                                    self.mut)
        rand_model_p.append(rand_lin_wp)

        #Batchnorm mods
        rand_bn_w0_w =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_w0_w)
        rand_bn_w0_b =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_w0_b)
        rand_bn_wp_w =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_wp_w)
        rand_bn_wp_b =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_wp_b)
        rand_bn_g_w =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_g_w)
        rand_bn_g_b =torch.mul(torch.randn(self.num_models, 1, requires_grad=False, device=self.device),
                                  self.mut)
        rand_model_p.append(rand_bn_g_b)
        '''
        return rand_model_p

    def save_plots(self,
                   gen,
                   rate=10,
                   plot_arr=(0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90)):
        if gen % rate == 0:
            fig = plt.figure()

            g, s = self.train_data[0]
            for subplot, idx in zip(range(1, len(plot_arr) + 1, 1), plot_arr):
                t = g[idx].view((1, 2))
                champ = self.models[self.sorted[0]]
                champ_out, w, wp, gn = champ(t)
                champ_out = champ_out[0].cpu().numpy()
                w_elite = self.models[self.sorted[self.trunc_threshold - 1]]
                w_elite_out, w, wp, gn = w_elite(t)
                w_elite_out = w_elite_out[0].cpu().numpy()
                worst = self.models[self.sorted[-1]]
                worst_out, w, wp, gn = worst(t)
                worst_out = worst_out[0].cpu().numpy()
                sn = s[idx].cpu().numpy()

                self.lorentz_plot.append(champ_out)
                self.lorentz_plot.append(w_elite_out)
                self.lorentz_plot.append(worst_out)
                self.lorentz_plot.append(sn)

                ax = fig.add_subplot(1, len(plot_arr), subplot)
                ax.plot(np.linspace(0.5, 5, 300),
                        worst_out,
                        color='tab:red',
                        label='worst')
                ax.plot(np.linspace(0.5, 5, 300),
                        w_elite_out,
                        color='tab:purple',
                        label='worst elite')
                ax.plot(np.linspace(0.5, 5, 300),
                        champ_out,
                        color='tab:blue',
                        label='Champion')
                ax.plot(np.linspace(0.5, 5, 300),
                        sn,
                        color='tab:orange',
                        label='Truth Spectra')
            plt.legend()
            plt.xlabel("Frequency (THz)")
            plt.ylabel("e2")

            self.writer.add_figure("{}_Lorentz_Evolution".format(gen), fig)

            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            champ_out, w, wp, gn = self.models[self.sorted[0]](g)
            hist_ydata = Forward.fitness_by_sample(champ_out, s).cpu().numpy()
            #hist_ydata = (Forward.fitness_by_sample(champ_out, s) - self.custom_loss_1(w,wp,gn,gen)).cpu().numpy()
            self.hist_plot.append(hist_ydata)
            ax.hist(hist_ydata, bins=np.arange(0, 5, 0.05))
            #ax.set_xticks([0,0.1,0.3,0.5,0.8,1,2,5,10,20,50])
            #plt.xscale("log")
            plt.xlabel("MSE_Loss over training")
            plt.ylabel("Number of datasets")

            self.writer.add_figure("{}_Histogram".format(gen), fig)
            plt.close()

    def custom_loss_1(self, w0, wp, g, gen):
        custom_loss = 0
        if w0 is not None:
            freq_mean = (self.flags['freq_low'] + self.flags['freq_high']) / 2
            freq_range = (self.flags['freq_high'] - self.flags['freq_low']) / 2
            custom_loss += torch.sum(
                torch.relu(torch.abs(w0 - freq_mean) - freq_range), 1)
        if g is not None:
            if gen is not None and gen < 100:
                custom_loss += torch.sum(torch.relu(-g + 0.05), 1)
            else:
                custom_loss += 100 * torch.sum(torch.relu(-g), 1)
        if wp is not None:
            custom_loss += 100 * torch.sum(torch.relu(-wp), 1)
        return custom_loss

    def custom_loss(self, w0, wp, g, gen):
        custom_loss = 0
        if w0 is not None:
            freq_mean = (self.flags['freq_low'] + self.flags['freq_high']) / 2
            freq_range = (self.flags['freq_high'] - self.flags['freq_low']) / 2
            custom_loss += torch.sum(
                torch.relu(torch.abs(w0 - freq_mean) - freq_range))
        if g is not None:
            if gen is not None and gen < 100:
                custom_loss += torch.sum(torch.relu(-g + 0.05))
            else:
                custom_loss += 100 * torch.sum(torch.relu(-g))
        if wp is not None:
            custom_loss += 100 * torch.sum(torch.relu(-wp))
        return custom_loss
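
A hedged sketch of how a GPUWorker might be driven over generations. The constructor arguments mirror the signature above, but the hyperparameter values are placeholders and `model_flags` must come from the surrounding project.

# Hypothetical driver loop; model_flags is assumed to be defined elsewhere.
worker = GPUWorker(device='cuda:0', pop=100, batches=10, top=5, trunc=0.1,
                   mut=0.05, model_flags=model_flags, maxgen=300)
for gen in range(worker.max_gen):
    worker.run(gen)  # one generation: evaluate fitness, sort, copy, mutate
    worker.save_plots(gen, plot_arr=[0, 18, 36, 63])  # optional TensorBoard figures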
Example #3
def optimize(
    model: ModelType,
    fun_data: FunDataType,
    fun_loss: FunLossType,
    plotfuns: PlotFunsType,
    optimizer_kind='Adam',
    max_epoch=100,
    patience=20,  # How many epochs to wait before quitting
    thres_patience=0.001,  # How much it should improve within patience
    learning_rate=.5,
    reduce_lr_by=0.5,
    reduced_lr_on_epoch=0,
    reduce_lr_after=50,
    reset_lr_after=100,
    to_plot_progress=True,
    show_progress_every=5,  # number of epochs
    to_print_grad=True,
    n_fold_valid=1,
    epoch_to_check=None,  # CHECKED
    comment='',
    **kwargs  # to ignore unnecessary kwargs
) -> (float, dict, dict, List[float], List[float]):
    """

    :param model:
    :param fun_data: (mode='all'|'train'|'valid'|'train_valid'|'test',
    fold_valid=0, epoch=0, n_fold_valid=1) -> (data, target)
    :param fun_loss: (out, target) -> loss
    :param plotfuns: [(str, fun)] where fun takes dict d with keys
    'data_*', 'target_*', 'out_*', 'loss_*', where * = 'train', 'valid', etc.
    :param optimizer_kind:
    :param max_epoch:
    :param patience:
    :param thres_patience:
    :param learning_rate:
    :param reduce_lr_by:
    :param reduced_lr_on_epoch:
    :param reduce_lr_after:
    :param to_plot_progress:
    :param show_progress_every:
    :param to_print_grad:
    :param n_fold_valid:
    :param kwargs:
    :return: loss_test, best_state, d, losses_train, losses_valid where d
    contains 'data_*', 'target_*', 'out_*', and 'loss_*', where * is
    'train_valid', 'test', and 'all'.
    """
    def get_optimizer(model, lr):
        if optimizer_kind == 'SGD':
            return optim.SGD(model.parameters(), lr=lr)
        elif optimizer_kind == 'Adam':
            return optim.Adam(model.parameters(), lr=lr)
        elif optimizer_kind == 'LBFGS':
            return optim.LBFGS(model.parameters(), lr=lr)
        else:
            raise NotImplementedError()

    learning_rate0 = learning_rate
    optimizer = get_optimizer(model, learning_rate)

    best_loss_epoch = 0
    best_loss_valid = np.inf
    best_state = model.state_dict()
    best_losses = []

    # CHECKED storing and loading states
    state0 = None
    loss0 = None
    data0 = None
    target0 = None
    out0 = None
    outs0 = None

    def array2str(v):
        return ', '.join(['%1.2g' % v1 for v1 in v.flatten()[:10]])

    def print_targ_out(target0, out0, outs0, loss0):
        print('target:\n' + array2str(target0))
        print('outs:\n' + '\n'.join(['[%s]' % array2str(v) for v in outs0]))
        print('out:\n' + array2str(out0))
        print('loss: ' + '%g' % loss0)

    def fun_outs(model, data):
        p_bef_lapse0 = model.dtb(*data)[0].detach().clone()
        p_aft_lapse0 = model.lapse(p_bef_lapse0).detach().clone()
        return [p_bef_lapse0, p_aft_lapse0]

    def are_all_equal(outs, outs0):
        for i, (out1, out0) in enumerate(zip(outs, outs0)):
            if (out1 != out0).any():
                warnings.warn('output %d different! max diff = %g' %
                              (i, (out1 - out0).abs().max()))
                print('--')

    # losses_train[epoch] = average cross-validated loss for the epoch
    losses_train = []
    losses_valid = []

    if to_plot_progress:
        writer = SummaryWriter(comment=comment)
    t_st = time.time()
    epoch = 0

    try:
        for epoch in range(max(max_epoch, 1)):
            losses_fold_train = []
            losses_fold_valid = []
            for i_fold in range(n_fold_valid):
                # NOTE: Core part
                data_train, target_train = fun_data('train', i_fold, epoch,
                                                    n_fold_valid)
                model.train()
                if optimizer_kind == 'LBFGS':

                    def closure():
                        optimizer.zero_grad()
                        out_train = model(data_train)
                        loss = fun_loss(out_train, target_train)
                        loss.backward()
                        return loss

                    if max_epoch > 0:
                        optimizer.step(closure)
                    out_train = model(data_train)
                    loss_train1 = fun_loss(out_train, target_train)
                    raise NotImplementedError(
                        'Restoring best state is not implemented yet')
                else:
                    optimizer.zero_grad()
                    out_train = model(data_train)
                    loss_train1 = fun_loss(out_train, target_train)
                    # DEBUGGED: optimizer.step() must not be taken before
                    #  storing best_loss or best_state

                losses_fold_train.append(loss_train1)

                if n_fold_valid == 1:
                    out_valid = npt.tensor(npy(out_train))
                    loss_valid1 = npt.tensor(npy(loss_train1))
                    data_valid = data_train
                    target_valid = target_train

                    # DEBUGGED: Unless directly assigned, target_valid !=
                    #  target_train when n_fold_valid = 1, which doesn't make
                    #  sense. Suggests a bug in fun_data when n_fold = 1
                else:
                    model.eval()
                    data_valid, target_valid = fun_data(
                        'valid', i_fold, epoch, n_fold_valid)
                    out_valid = model(data_valid)
                    loss_valid1 = fun_loss(out_valid, target_valid)
                    model.train()
                losses_fold_valid.append(loss_valid1)

            loss_train = torch.mean(torch.stack(losses_fold_train))
            loss_valid = torch.mean(torch.stack(losses_fold_valid))
            losses_train.append(npy(loss_train))
            losses_valid.append(npy(loss_valid))

            if to_plot_progress:
                writer.add_scalar('loss_train', loss_train, global_step=epoch)
                writer.add_scalar('loss_valid', loss_valid, global_step=epoch)

            # --- Store best loss
            # NOTE: storing losses/states must happen BEFORE taking a step!
            if loss_valid < best_loss_valid:
                # is_best = True
                best_loss_epoch = deepcopy(epoch)
                best_loss_valid = npt.tensor(npy(loss_valid))
                best_state = model.state_dict()

            best_losses.append(best_loss_valid)

            # CHECKED storing and loading state
            if epoch == epoch_to_check:
                loss0 = loss_valid.detach().clone()
                state0 = model.state_dict()
                data0 = deepcopy(data_valid)
                target0 = deepcopy(target_valid)
                out0 = out_valid.detach().clone()
                outs0 = fun_outs(model, data0)

                loss001 = fun_loss(out0, target0)
                # CHECKED: loss001 must equal loss0
                print('loss001 - loss0: %g' % (loss001 - loss0))

                print_targ_out(target0, out0, outs0, loss0)
                print('--')

            def print_loss():
                t_el = time.time() - t_st
                print('%1.0f sec/%d epochs = %1.1f sec/epoch, Ltrain: %f, '
                      'Lvalid: %f, LR: %g, best: %f, epochB: %d' %
                      (t_el, epoch + 1, t_el /
                       (epoch + 1), loss_train, loss_valid, learning_rate,
                       best_loss_valid, best_loss_epoch))

            if epoch % show_progress_every == 0:
                model.train()
                data_train_valid, target_train_valid = fun_data(
                    'train_valid', i_fold, epoch, n_fold_valid)
                out_train_valid = model(data_train_valid)
                loss_train_valid = fun_loss(out_train_valid,
                                            target_train_valid)
                print_loss()
                if to_plot_progress:
                    d = {
                        'data_train': data_train,
                        'data_valid': data_valid,
                        'data_train_valid': data_train_valid,
                        'out_train': out_train.detach(),
                        'out_valid': out_valid.detach(),
                        'out_train_valid': out_train_valid.detach(),
                        'target_train': target_train.detach(),
                        'target_valid': target_valid.detach(),
                        'target_train_valid': target_train_valid.detach(),
                        'loss_train': loss_train.detach(),
                        'loss_valid': loss_valid.detach(),
                        'loss_train_valid': loss_train_valid.detach()
                    }

                    for k, f in odict(plotfuns).items():
                        fig, d = f(model, d)
                        if fig is not None:
                            writer.add_figure(k, fig, global_step=epoch)

            # --- Learning rate reduction and patience
            # if epoch == reduced_lr_on_epoch + reset_lr_after
            # if epoch == reduced_lr_on_epoch + reduce_lr_after and (
            #         best_loss_valid
            #         > best_losses[-reduce_lr_after] - thres_patience
            # ):
            if epoch > 0 and epoch % reset_lr_after == 0:
                learning_rate = learning_rate0
                optimizer = get_optimizer(model, learning_rate)
            elif epoch > 0 and epoch % reduce_lr_after == 0:
                learning_rate *= reduce_lr_by
                optimizer = get_optimizer(model, learning_rate)
                reduced_lr_on_epoch = epoch

            if epoch >= patience and (best_loss_valid >
                                      best_losses[-patience] - thres_patience):
                print('Ran out of patience!')
                if to_print_grad:
                    print_grad(model)
                break

            # --- Take a step
            if optimizer_kind != 'LBFGS':
                # steps are not taken above for n_fold_valid == 1, so take a
                # step here, after storing the best state
                loss_train.backward()
                if to_print_grad and epoch == 0:
                    print_grad(model)
                if max_epoch > 0:
                    optimizer.step()

    except Exception as ex:
        from lib.pylabyk.cacheutil import is_keyboard_interrupt
        if not is_keyboard_interrupt(ex):
            raise ex
        print('fit interrupted by user at epoch %d' % epoch)

        from lib.pylabyk.localfile import LocalFile, datetime4filename
        localfile = LocalFile()
        cache = localfile.get_cache('model_data_target')
        data_train_valid, target_train_valid = fun_data(
            'all', 0, 0, n_fold_valid)
        cache.set({
            'model': model,
            'data_train_valid': data_train_valid,
            'target_train_valid': target_train_valid
        })
        cache.save()

    print_loss()
    if to_plot_progress:
        writer.close()

    if epoch_to_check is not None:
        # Must print the same output as previous call to print_targ_out
        print_targ_out(target0, out0, outs0, loss0)

        model.load_state_dict(state0)
        state1 = model.state_dict()
        for (key0, param0), (key1, param1) in zip(
                state0.items(), state1.items()
        ):  # type: ((str, torch.Tensor), (str, torch.Tensor))
            if (param0 != param1).any():
                with torch.no_grad():
                    warnings.warn(
                        'Strange! loaded %s = %s\n'
                        '!= stored %s = %s\n'
                        'loaded - stored = %s' %
                        (key1, param1, key0, param0, param1 - param0))
        data, target = fun_data('valid', 0, epoch_to_check, n_fold_valid)

        if not torch.is_tensor(data):
            p_unequal = torch.tensor([(v1 != v0).double().mean()
                                      for v1, v0 in zip(data, data0)])
            if (p_unequal > 0).any():
                print('Strange! loaded data != stored data0\n'
                      'Proportion: %s' % p_unequal)
            else:
                print('All loaded data == stored data')
        elif (data != data0).any():
            print('Strange! loaded data != stored data0')
        else:
            print('All loaded data == stored data')

        if (target != target0).any():
            print('Strange! loaded target != stored target0')
        else:
            print('All loaded target == stored target')

        print_targ_out(target0, out0, outs0, loss0)

        # with torch.no_grad():
        #     out01 = model(data0)
        #     loss01 = fun_loss(out01, target0)
        model.train()
        # with torch.no_grad():
        # CHECKED
        # outs1 = fun_outs(model, data)
        # are_all_equal(outs1, outs0)

        out1 = model(data)
        if (out0 != out1).any():
            warnings.warn('Strange! out from loaded params != stored out\n'
                          'Max abs(loaded - stored): %g' %
                          (out1 - out0).abs().max())
            print('--')
        else:
            print('out from loaded params = stored out')

        loss01 = fun_loss(out0, target0)
        print_targ_out(target0, out0, outs0, loss01)

        if loss0 != loss01:
            warnings.warn(
                'Strange!  loss1 = %g simply computed again with out0, '
                'target0\n'
                '!= stored loss0 = %g\n'
                'loaded - stored:  %g\n'
                'Therefore, fun_loss, out0, or target0 has changed!' %
                (loss01, loss0, loss01 - loss0))
            print('--')
        else:
            print('loss0 == loss01, simply computed again with out0, target0')

        loss1 = fun_loss(out1, target)
        if loss0 != loss1:
            warnings.warn('Strange!  loss1 = %g from loaded params\n'
                          '!= stored loss0 = %g\n'
                          'loaded - stored:  %g' %
                          (loss1, loss0, loss1 - loss0))
            print('--')
        else:
            print('loss1 = %g = loss0 = %g' % (loss1, loss0))

        loss10 = fun_loss(out1, target0)
        if loss0 != loss10:
            warnings.warn(
                'Strange!  loss10 = %g from loaded params and stored '
                'target0\n'
                '!= stored loss0 = %g\n'
                'loaded - stored:  %g' % (loss10, loss0, loss10 - loss0))
            print('--')
        else:
            print('loss10 = %g = loss0 = %g' % (loss10, loss0))
        print('--')

    model.load_state_dict(best_state)

    d = {}
    for mode in ['train_valid', 'valid', 'test', 'all']:
        data, target = fun_data(mode, 0, 0, n_fold_valid)
        out = model(data)
        loss = fun_loss(out, target)
        d.update({
            'data_' + mode: data,
            'target_' + mode: target,
            'out_' + mode: npt.tensor(npy(out)),
            'loss_' + mode: npt.tensor(npy(loss))
        })

    if d['loss_valid'] != best_loss_valid:
        print('d[loss_valid]      = %g from loaded best_state \n'
              '!= best_loss_valid = %g\n'
              'd[loss_valid] - best_loss_valid = %g' %
              (d['loss_valid'], best_loss_valid,
               d['loss_valid'] - best_loss_valid))
        print('--')

    if isinstance(model, OverriddenParameter):
        print(model.__str__())
    elif isinstance(model, BoundedModule):
        pprint(model._parameters_incl_bounded)
    else:
        pprint(model.state_dict())

    return d['loss_test'], best_state, d, losses_train, losses_valid
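
A minimal, hypothetical call of optimize on a toy regression problem, only to illustrate the expected signatures of fun_data and fun_loss; the model and data here are stand-ins, not part of the original project.

import torch

model = torch.nn.Linear(3, 1)
data_all = torch.randn(100, 3)
target_all = data_all @ torch.tensor([[1.0], [-2.0], [0.5]])

def fun_data(mode, fold_valid=0, epoch=0, n_fold_valid=1):
    # Toy version: every mode ('train', 'valid', 'test', 'all', ...) sees the same data.
    return data_all, target_all

def fun_loss(out, target):
    return torch.nn.functional.mse_loss(out, target)

loss_test, best_state, d, losses_train, losses_valid = optimize(
    model, fun_data, fun_loss, plotfuns=[],
    max_epoch=50, to_plot_progress=False, to_print_grad=False)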
Example #4
File: snl.py  Project: yyht/lfi
class SNL:
    """
    Implementation of
    'Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows'
    Papamakarios et al.
    AISTATS 2019
    https://arxiv.org/abs/1805.07226
    """
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        neural_likelihood,
        mcmc_method="slice-np",
        summary_writer=None,
    ):
        """

        :param simulator: Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter as a torch.Tensor.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
        :param neural_likelihood: Conditional density estimator q(x | theta) in the form of a
        nets.Module. Must have 'log_prob' and 'sample' methods.
        :param mcmc_method: MCMC method to use for posterior sampling. Must be one of
        ['slice-np', 'slice', 'hmc', 'nuts'].
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_likelihood = neural_likelihood
        self._mcmc_method = mcmc_method

        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            neural_likelihood=self._neural_likelihood,
            prior=self._prior,
            true_observation=self._true_observation,
        )

        # TODO: decide on Slice Sampling implementation
        target_log_prob = (lambda parameters: self._neural_likelihood.log_prob(
            inputs=self._true_observation.reshape(1, -1),
            context=torch.Tensor(parameters).reshape(1, -1),
        ).item() + self._prior.log_prob(torch.Tensor(parameters)).sum().item())
        self._neural_likelihood.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._neural_likelihood.train()

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank = [], []

        # Each SNL run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(utils.get_log_root(), "snl", simulator.name,
                                   utils.get_timestamp())
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "mcmc-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
        }

    def run_inference(self, num_rounds, num_simulations_per_round):
        """
        This runs SNL for num_rounds rounds, using num_simulations_per_round calls to
        the simulator per round.

        :param num_rounds: Number of rounds to run.
        :param num_simulations_per_round: Number of simulator calls per round.
        :return: None
        """

        round_description = ""
        tbar = tqdm(range(num_rounds))
        for round_ in tbar:

            tbar.set_description(round_description)

            # Generate parameters from prior in first round, and from most recent posterior
            # estimate in subsequent rounds.
            if round_ == 0:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self._prior.sample(
                        (num_samples, )),
                    num_samples=num_simulations_per_round,
                )
            else:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self.
                    sample_posterior(num_samples),
                    num_samples=num_simulations_per_round,
                )

            # Store (parameter, observation) pairs.
            self._parameter_bank.append(torch.Tensor(parameters))
            self._observation_bank.append(torch.Tensor(observations))

            # Fit neural likelihood to newly aggregated dataset.
            self._fit_likelihood()

            # Update description for progress bar.
            round_description = (
                f"-------------------------\n"
                f"||||| ROUND {round_ + 1} STATS |||||:\n"
                f"-------------------------\n"
                f"Epochs trained: {self._summary['epochs'][-1]}\n"
                f"Best validation performance: {self._summary['best-validation-log-probs'][-1]:.4f}\n\n"
            )

            # Update TensorBoard and summary dict.
            self._summarize(round_)

    def sample_posterior(self, num_samples, thin=1):
        """
        Samples from posterior for true observation q(theta | x0) ~ q(x0 | theta) p(theta)
        using most recent likelihood estimate q(x0 | theta) with MCMC.

        :param num_samples: Number of samples to generate.
        :param thin: Generate (num_samples * thin) samples in total, then select every
        'thin' sample (note: thinning is not currently applied in this method).
        :return: torch.Tensor of shape [num_samples, parameter_dim]
        """

        # Always sample in eval mode.
        self._neural_likelihood.eval()

        if self._mcmc_method == "slice-np":
            self.posterior_sampler.gen(20)
            samples = torch.Tensor(self.posterior_sampler.gen(num_samples))

        else:
            if self._mcmc_method == "slice":
                kernel = Slice(potential_function=self._potential_function)
            elif self._mcmc_method == "hmc":
                kernel = HMC(potential_fn=self._potential_function)
            elif self._mcmc_method == "nuts":
                kernel = NUTS(potential_fn=self._potential_function)
            else:
                raise ValueError(
                    "'mcmc_method' must be one of ['slice', 'hmc', 'nuts'].")
            num_chains = mp.cpu_count() - 1

            # TODO: decide on way to initialize chain
            initial_params = self._prior.sample((num_chains, ))
            sampler = MCMC(
                kernel=kernel,
                num_samples=num_samples // num_chains + num_chains,
                warmup_steps=200,
                initial_params={"": initial_params},
                num_chains=num_chains,
            )
            sampler.run()
            samples = next(iter(sampler.get_samples().values())).reshape(
                -1, self._simulator.parameter_dim)

            samples = samples[:num_samples].to(device)
            assert samples.shape[0] == num_samples

        # Back to training mode.
        self._neural_likelihood.train()

        return samples

    def _fit_likelihood(
        self,
        batch_size=100,
        learning_rate=5e-4,
        validation_fraction=0.1,
        stop_after_epochs=20,
    ):
        """
        Trains the conditional density estimator for the likelihood by maximum likelihood
        on the most recently aggregated bank of (parameter, observation) pairs.
        Uses early stopping on a held-out validation set as a terminating condition.

        :param batch_size: Size of batch to use for training.
        :param learning_rate: Learning rate for Adam optimizer.
        :param validation_fraction: The fraction of data to use for validation.
        :param stop_after_epochs: The number of epochs to wait for improvement on the
        validation set before terminating training.
        :return: None
        """

        # Get total number of training examples.
        num_examples = torch.cat(self._parameter_bank).shape[0]

        # Select random train and validation splits from (parameter, observation) pairs.
        permuted_indices = torch.randperm(num_examples)
        num_training_examples = int((1 - validation_fraction) * num_examples)
        num_validation_examples = num_examples - num_training_examples
        train_indices, val_indices = (
            permuted_indices[:num_training_examples],
            permuted_indices[num_training_examples:],
        )

        # Dataset is shared for training and validation loaders.
        dataset = data.TensorDataset(torch.cat(self._observation_bank),
                                     torch.cat(self._parameter_bank))

        # Create train and validation loaders using a subset sampler.
        train_loader = data.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=True,
            sampler=SubsetRandomSampler(train_indices),
        )
        val_loader = data.DataLoader(
            dataset,
            batch_size=min(batch_size, num_examples - num_training_examples),
            shuffle=False,
            drop_last=False,
            sampler=SubsetRandomSampler(val_indices),
        )

        optimizer = optim.Adam(self._neural_likelihood.parameters(),
                               lr=learning_rate)
        # Keep track of best_validation log_prob seen so far.
        best_validation_log_prob = -1e100
        # Keep track of number of epochs since last improvement.
        epochs_since_last_improvement = 0
        # Keep track of model with best validation performance.
        best_model_state_dict = None

        epochs = 0
        while True:

            # Train for a single epoch.
            self._neural_likelihood.train()
            for batch in train_loader:
                optimizer.zero_grad()
                inputs, context = batch[0].to(device), batch[1].to(device)
                log_prob = self._neural_likelihood.log_prob(inputs,
                                                            context=context)
                loss = -torch.mean(log_prob)
                loss.backward()
                clip_grad_norm_(self._neural_likelihood.parameters(),
                                max_norm=5.0)
                optimizer.step()

            epochs += 1

            # Calculate validation performance.
            self._neural_likelihood.eval()
            log_prob_sum = 0
            with torch.no_grad():
                for batch in val_loader:
                    inputs, context = batch[0].to(device), batch[1].to(device)
                    log_prob = self._neural_likelihood.log_prob(
                        inputs, context=context)
                    log_prob_sum += log_prob.sum().item()
            validation_log_prob = log_prob_sum / num_validation_examples

            # Check for improvement in validation performance over previous epochs.
            if validation_log_prob > best_validation_log_prob:
                best_validation_log_prob = validation_log_prob
                epochs_since_last_improvement = 0
                best_model_state_dict = deepcopy(
                    self._neural_likelihood.state_dict())
            else:
                epochs_since_last_improvement += 1

            # If no validation improvement over many epochs, stop training.
            if epochs_since_last_improvement > stop_after_epochs - 1:
                self._neural_likelihood.load_state_dict(best_model_state_dict)
                break

        # Update summary.
        self._summary["epochs"].append(epochs)
        self._summary["best-validation-log-probs"].append(
            best_validation_log_prob)

    @property
    def summary(self):
        return self._summary

    def _summarize(self, round_):

        # Update summaries.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(
                    num_samples=1000),
            )
            print(mmd.item())
            self._summary["mmds"].append(mmd.item())
        except Exception:
            # Ground-truth posterior samples may be unavailable for this simulator.
            pass

        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._observation_bank[-1] -
                     self._true_observation.reshape(1, -1))**2,
                    dim=-1,
                )))
        self._summary["median-observation-distances"].append(
            median_observation_distance.item())

        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item())

        # Plot most recently sampled parameters in TensorBoard.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )
        self._summary_writer.add_figure(tag="posterior-samples",
                                        figure=figure,
                                        global_step=round_ + 1)

        self._summary_writer.add_scalar(
            tag="epochs-trained",
            scalar_value=self._summary["epochs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="median-observation-distance",
            scalar_value=self._summary["median-observation-distances"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="negative-log-prob-true-parameters",
            scalar_value=self._summary["negative-log-probs-true-parameters"]
            [-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="best-validation-log-prob",
            scalar_value=self._summary["best-validation-log-probs"][-1],
            global_step=round_ + 1,
        )

        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
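
A hedged sketch of using the SNL class end to end; `simulator`, `prior`, `true_observation`, and `neural_likelihood` are placeholders that would come from the surrounding lfi codebase.

snl = SNL(
    simulator=simulator,
    prior=prior,
    true_observation=true_observation,
    neural_likelihood=neural_likelihood,
    mcmc_method="slice-np",
)
snl.run_inference(num_rounds=10, num_simulations_per_round=1000)
posterior_samples = snl.sample_posterior(num_samples=5000)
print(snl.summary["best-validation-log-probs"])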
Example #5
File: sre.py  Project: plcrodrigues/lfi
class SRE:
    """
    Implementation of 'Sequential Ratio Estimation', as presented in
    'Likelihood-free MCMC with Amortized Approximate Likelihood Ratios'
    Hermans et al.
    Pre-print 2019
    https://arxiv.org/abs/1903.04057
    """
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        classifier,
        num_atoms=-1,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        summary_writer=None,
    ):
        """
        :param simulator: Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter as a torch.Tensor.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
        :param classifier: Binary classifier in the form of a nets.Module.
        Takes as input (x, theta) pairs and outputs pre-sigmoid activations.
        :param num_atoms: int
            Number of atoms to use for classification.
            If -1, use all other parameters in minibatch.
        :param mcmc_method: MCMC method to use for posterior sampling. Must be one of
        ['slice-np', 'slice', 'hmc', 'nuts'].
        :param summary_net: Optional network which may be used to produce feature vectors
        f(x) for high-dimensional observations.
        :param retrain_from_scratch_each_round: Whether to retrain the classifier
        from scratch each round.
        """

        self._simulator = simulator
        self._true_observation = true_observation
        self._classifier = classifier
        self._prior = prior

        assert isinstance(num_atoms,
                          int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        self._mcmc_method = mcmc_method

        # We may want to summarize high-dimensional observations.
        # This may be either a fixed or learned transformation.
        if summary_net is None:
            self._summary_net = nn.Identity()
        else:
            self._summary_net = summary_net

        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            classifier, prior, true_observation)

        # TODO: decide on Slice Sampling implementation
        target_log_prob = (lambda parameters: self._classifier(
            torch.cat(
                (torch.Tensor(parameters), self._true_observation)).reshape(
                    1, -1)).item() + self._prior.log_prob(
                        torch.Tensor(parameters)).sum().item())
        self._classifier.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._classifier.train()

        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        # If we're retraining from scratch each round,
        # keep a copy of the original untrained model for reinitialization.
        if retrain_from_scratch_each_round:
            self._untrained_classifier = deepcopy(classifier)
        else:
            self._untrained_classifier = None

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank = [], []

        # Each SRE run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(utils.get_log_root(), "sre", simulator.name,
                                   utils.get_timestamp())
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "mcmc-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
        }

    def run_inference(self, num_rounds, num_simulations_per_round):
        """
        This runs SRE for num_rounds rounds, using num_simulations_per_round calls to
        the simulator per round.

        :param num_rounds: Number of rounds to run.
        :param num_simulations_per_round: Number of simulator calls per round.
        :return: None
        """

        round_description = ""
        tbar = tqdm(range(num_rounds))
        for round_ in tbar:

            tbar.set_description(round_description)

            # Generate parameters from prior in first round, and from most recent posterior
            # estimate in subsequent rounds.
            if round_ == 0:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self._prior.sample(
                        (num_samples, )),
                    num_samples=num_simulations_per_round,
                )
            else:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self.
                    sample_posterior(num_samples),
                    num_samples=num_simulations_per_round,
                )

            # Store (parameter, observation) pairs.
            self._parameter_bank.append(torch.Tensor(parameters))
            self._observation_bank.append(torch.Tensor(observations))

            # Fit posterior using newly aggregated data set.
            self._fit_classifier()

            # Update description for progress bar.
            round_description = (
                f"-------------------------\n"
                f"||||| ROUND {round_ + 1} STATS |||||:\n"
                f"-------------------------\n"
                f"Epochs trained: {self._summary['epochs'][-1]}\n"
                f"Best validation performance: {self._summary['best-validation-log-probs'][-1]:.4f}\n\n"
            )

            # Update tensorboard and summary dict.
            self._summarize(round_)

    def sample_posterior(self, num_samples, thin=10):
        """
        Samples from posterior for true observation q(theta | x0) ~ r(x0, theta) p(theta)
        using most recent ratio estimate r(x0, theta) with MCMC.

        :param num_samples: Number of samples to generate.
        :param thin: Generate (num_samples * thin) samples in total, then select every
        'thin' sample.
        :return: torch.Tensor of shape [num_samples, parameter_dim]
        """

        # Always sample in eval mode.
        self._classifier.eval()

        if self._mcmc_method == "slice-np":
            self.posterior_sampler.gen(20)
            samples = torch.Tensor(self.posterior_sampler.gen(num_samples))

        else:
            if self._mcmc_method == "slice":
                kernel = Slice(potential_function=self._potential_function)
            elif self._mcmc_method == "hmc":
                kernel = HMC(potential_fn=self._potential_function)
            elif self._mcmc_method == "nuts":
                kernel = NUTS(potential_fn=self._potential_function)
            else:
                raise ValueError(
                    "'mcmc_method' must be one of ['slice', 'hmc', 'nuts'].")
            num_chains = mp.cpu_count() - 1

            initial_params = self._prior.sample((num_chains, ))
            sampler = MCMC(
                kernel=kernel,
                num_samples=(thin * num_samples) // num_chains + num_chains,
                warmup_steps=200,
                initial_params={"": initial_params},
                num_chains=num_chains,
                mp_context="spawn",
            )
            sampler.run()
            samples = next(iter(sampler.get_samples().values())).reshape(
                -1, self._simulator.parameter_dim)

            # Keep every 'thin'-th draw, then truncate to the requested count.
            samples = samples[::thin][:num_samples]
            assert samples.shape[0] == num_samples

        # Back to training mode.
        self._classifier.train()

        return samples

    def _fit_classifier(
        self,
        batch_size=100,
        learning_rate=5e-4,
        validation_fraction=0.1,
        stop_after_epochs=20,
    ):
        """
        Trains the classifier by maximizing a Bernoulli likelihood which distinguishes
        between jointly distributed (parameter, observation) pairs and randomly chosen
        (parameter, observation) pairs.
        Uses early stopping on a held-out validation set as a terminating condition.

        :param batch_size: Size of batch to use for training.
        :param learning_rate: Learning rate for Adam optimizer.
        :param validation_fraction: The fraction of data to use for validation.
        :param stop_after_epochs: The number of epochs to wait for improvement on the
        validation set before terminating training.
        :return: None
        """

        # Get total number of training examples.
        num_examples = torch.cat(self._parameter_bank).shape[0]

        # Select random train and validation splits from (parameter, observation) pairs.
        permuted_indices = torch.randperm(num_examples)
        num_training_examples = int((1 - validation_fraction) * num_examples)
        num_validation_examples = num_examples - num_training_examples
        train_indices, val_indices = (
            permuted_indices[:num_training_examples],
            permuted_indices[num_training_examples:],
        )

        # Dataset is shared for training and validation loaders.
        dataset = data.TensorDataset(torch.cat(self._parameter_bank),
                                     torch.cat(self._observation_bank))

        # Create train and validation loaders using a subset sampler.
        train_loader = data.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=True,
            sampler=SubsetRandomSampler(train_indices),
        )
        val_loader = data.DataLoader(
            dataset,
            batch_size=min(batch_size, num_validation_examples),
            shuffle=False,
            drop_last=False,
            sampler=SubsetRandomSampler(val_indices),
        )

        optimizer = optim.Adam(
            list(self._classifier.parameters()) +
            list(self._summary_net.parameters()),
            lr=learning_rate,
        )

        # Keep track of best_validation log_prob seen so far.
        best_validation_log_prob = -1e100
        # Keep track of number of epochs since last improvement.
        epochs_since_last_improvement = 0
        # Keep track of model with best validation performance.
        best_model_state_dict = None

        # If we're retraining from scratch each round, reset the classifier to
        # the untrained copy made at the start (assumed to be stored as
        # `_untrained_classifier`, mirroring APT's `_untrained_neural_posterior`;
        # deep-copying the live classifier would be a no-op).
        if self._retrain_from_scratch_each_round:
            self._classifier = deepcopy(self._untrained_classifier)

        def _get_log_prob(parameters, observations):

            # num_atoms = parameters.shape[0]
            num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

            repeated_observations = utils.repeat_rows(observations, num_atoms)

            # Choose between 1 and num_atoms - 1 parameters from the rest
            # of the batch for each observation.
            assert 0 < num_atoms - 1 < batch_size
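            # Uniform selection probabilities over the other elements of the
            # batch: (1 - eye) zeroes out each observation's own parameter, so
            # contrasting parameters are drawn from the remaining batch entries.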
            probs = ((1 /
                      (batch_size - 1)) * torch.ones(batch_size, batch_size) *
                     (1 - torch.eye(batch_size)))
            choices = torch.multinomial(probs,
                                        num_samples=num_atoms - 1,
                                        replacement=False)
            contrasting_parameters = parameters[choices]

            atomic_parameters = torch.cat(
                (parameters[:, None, :], contrasting_parameters),
                dim=1).reshape(batch_size * num_atoms, -1)

            inputs = torch.cat((atomic_parameters, repeated_observations),
                               dim=1)

            logits = self._classifier(inputs).reshape(batch_size, num_atoms)

            # Atom 0 in each row is the jointly drawn (parameter, observation)
            # pair, so this log-softmax over atoms is the log-probability the
            # classifier assigns to the true pairing.
            log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)

            return log_prob

        epochs = 0
        while True:

            # Train for a single epoch.
            self._classifier.train()
            for parameters, observations in train_loader:
                optimizer.zero_grad()
                log_prob = _get_log_prob(parameters, observations)
                loss = -torch.mean(log_prob)
                loss.backward()
                optimizer.step()

            epochs += 1

            # calculate validation performance
            self._classifier.eval()
            log_prob_sum = 0
            with torch.no_grad():
                for parameters, observations in val_loader:
                    log_prob = _get_log_prob(parameters, observations)
                    log_prob_sum += log_prob.sum().item()
                validation_log_prob = log_prob_sum / num_validation_examples

            # check for improvement
            if validation_log_prob > best_validation_log_prob:
                best_model_state_dict = deepcopy(self._classifier.state_dict())
                best_validation_log_prob = validation_log_prob
                epochs_since_last_improvement = 0
            else:
                epochs_since_last_improvement += 1

            # if no validation improvement over many epochs, stop training
            if epochs_since_last_improvement > stop_after_epochs - 1:
                self._classifier.load_state_dict(best_model_state_dict)
                break

        # Update summary.
        self._summary["epochs"].append(epochs)
        self._summary["best-validation-log-probs"].append(
            best_validation_log_prob)

    @property
    def summary(self):
        return self._summary

    def _summarize(self, round_):

        # Update summaries.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(
                    num_samples=1000),
            )
            self._summary["mmds"].append(mmd.item())
        except Exception:
            # Ground-truth posterior samples are unavailable for some simulators.
            pass

        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._observation_bank[-1] -
                     self._true_observation.reshape(1, -1))**2,
                    dim=-1,
                )))
        self._summary["median-observation-distances"].append(
            median_observation_distance.item())

        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item())

        # Plot most recently sampled parameters in TensorBoard.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )
        self._summary_writer.add_figure(tag="posterior-samples",
                                        figure=figure,
                                        global_step=round_ + 1)

        self._summary_writer.add_scalar(
            tag="epochs-trained",
            scalar_value=self._summary["epochs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="best-validation-log-prob",
            scalar_value=self._summary["best-validation-log-probs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="median-observation-distance",
            scalar_value=self._summary["median-observation-distances"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="negative-log-prob-true-parameters",
            scalar_value=self._summary["negative-log-probs-true-parameters"]
            [-1],
            global_step=round_ + 1,
        )

        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
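
# A minimal usage sketch (illustrative, not from the source project; the class
# name `SRE` and the constructor arguments are assumed, mirroring the APT class
# in Example #9 below):
#
#     sre = SRE(simulator=simulator, prior=prior,
#               true_observation=x0, classifier=classifier)
#     sre.run_inference(num_rounds=10, num_simulations_per_round=1000)
#     samples = sre.sample_posterior(num_samples=5000)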
class DurationExtractor(nn.Module):
    """The teacher model for duration extraction"""
    def __init__(self,
                 adam_lr=0.002,
                 warmup_epochs=30,
                 init_scale=0.25,
                 guided_att_sigma=0.3,
                 device='cuda'):
        super(DurationExtractor, self).__init__()

        self.txt_encoder = ConvTextEncoder()
        self.audio_encoder = ConvAudioEncoder()
        self.audio_decoder = ConvAudioDecoder()
        self.attention = ScaledDotAttention()
        self.collate = Collate(device=device)

        # optim
        self.optimizer = torch.optim.Adam(self.parameters(), lr=adam_lr)
        self.scheduler = NoamScheduler(self.optimizer, warmup_epochs,
                                       init_scale)

        # losses
        self.loss_l1 = l1_masked
        self.loss_att = GuidedAttentionLoss(guided_att_sigma)

        # device
        self.device = device
        self.to(self.device)
        print(f'Model sent to {self.device}')

        # helper vars
        self.checkpoint = None
        self.epoch = 0
        self.step = 0

        repo = git.Repo(search_parent_directories=True)
        self.git_commit = repo.head.object.hexsha

    def to_device(self, device):
        print(f'Sending network to {device}')
        self.device = device
        self.to(device)
        return self

    def save(self):

        if self.checkpoint is not None:
            os.remove(self.checkpoint)
        self.checkpoint = os.path.join(
            self.logger.log_dir,
            f'{time.strftime("%Y-%m-%d")}_checkpoint_step{self.step}.pth')
        torch.save(
            {
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'git_commit': self.git_commit
            }, self.checkpoint)

    def load(self, checkpoint):
        checkpoint = torch.load(checkpoint)
        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

        commit = checkpoint['git_commit']
        if commit != self.git_commit:
            print(
                f'Warning: the loaded checkpoint was trained on commit {commit}, but you are on {self.git_commit}'
            )
        self.checkpoint = None  # prevent save() from deleting the loaded checkpoint
        return self

    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """
        :param phonemes: (batch, alphabet, time), padded phonemes
        :param spectrograms: (batch, freq, time), padded spectrograms
        :param len_phonemes: list of phoneme lengths
        :return: decoded_spectrograms, attention_weights
        """
        spectrs = ZeroPad2d(
            (0, 0, 1, 0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1],
                                        w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1],
                                           queries.shape[1],
                                           w=1).to(self.device)

        attention, weights = self.attention(queries,
                                            keys,
                                            values,
                                            mask=att_mask)
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights

    def generating(self, mode):
        """Put the module into mode for sequential generation"""
        for module in self.children():
            if hasattr(module, 'generating'):
                module.generating(mode)

    def generate(self,
                 phonemes,
                 len_phonemes,
                 steps=False,
                 window=3,
                 spectrograms=None):
        """Sequentially generate spectrogram from phonemes
        
        If spectrograms are provided, they are used on input instead of self-generated frames (teacher forcing)
        If steps are provided with spectrograms, only 'steps' frames will be generated in supervised fashion
        Uses layer-level caching for faster inference.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate
        :param window: Window size for attention masking
        :param spectrograms: Padded spectrograms
        :return: Generated spectrograms
        """
        self.generating(True)
        self.train(False)

        assert steps or (spectrograms is not None)
        steps = steps if steps else spectrograms.shape[1]

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1],
                                            keys.shape[1],
                                            w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps,
                                         w=1).to(self.device)

            if spectrograms is None:
                dec = torch.zeros(len(phonemes),
                                  1,
                                  hp.out_channels,
                                  device=self.device)
            else:
                input = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

            weights, decoded = None, None

            if window is not None:
                shape = (len(phonemes), 1, phonemes.shape[-1])
                idx = torch.zeros(len(phonemes), 1,
                                  phonemes.shape[-1]).to(phonemes.device)
                att_mask = idx_mask(shape, idx, window)
            else:
                att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                lengths=len_phonemes,
                                dim=-1).to(self.device)

            for i in range(steps):
                if spectrograms is None:
                    queries = self.audio_encoder(dec)
                else:
                    queries = self.audio_encoder(input[:, i:i + 1, :])

                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                dec = self.audio_decoder(att + queries)
                weights = w if weights is None else torch.cat(
                    (weights, w), dim=1)
                decoded = dec if decoded is None else torch.cat(
                    (decoded, dec), dim=1)
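                # Re-center the attention window on the currently most-attended
                # phoneme so the alignment can only advance locally.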
                if window is not None:
                    idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                    att_mask = idx_mask(shape, idx, window)

        self.generating(False)
        return decoded, weights

    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
        """Naive generation without layer-level caching for testing purposes"""

        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1],
                                            keys.shape[1],
                                            w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps,
                                         w=1).to(self.device)

            dec = torch.zeros(len(phonemes),
                              1,
                              hp.out_channels,
                              device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                print(i)
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat(
                    (weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights

    def fit(self,
            batch_size,
            logdir,
            epochs=1,
            grad_clip=1,
            checkpoint_every=10):
        self.grad_clip = grad_clip
        self.logger = SummaryWriter(logdir)

        train_loader = self.train_dataloader(batch_size)
        valid_loader = self.val_dataloader(batch_size)

        # continue training from self.epoch if checkpoint loaded
        for e in range(self.epoch + 1, self.epoch + 1 + epochs):
            self.epoch = e
            train_losses = self._train_epoch(train_loader)
            valid_losses = self._validate(valid_loader)

            self.scheduler.step()
            self.logger.add_scalar('train/learning_rate',
                                   self.optimizer.param_groups[0]['lr'],
                                   self.epoch)
            if not e % checkpoint_every:
                self.save()

            print(
                f'Epoch {e} | Train - l1: {train_losses[0]}, guided_att: {train_losses[1]}| '
                f'Valid - l1: {valid_losses[0]}, guided_att: {valid_losses[1]}|'
            )

    def _train_epoch(self, dataloader):
        self.train()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(Bar(dataloader)):
            self.optimizer.zero_grad()
            spectrs, slen, phonemes, plen, text = batch

            s = add_random_noise(spectrs, hp.noise)
            s = degrade_some(self,
                             s,
                             phonemes,
                             plen,
                             hp.feed_ratio,
                             repeat=hp.feed_repeat)
            s = frame_dropout(s, hp.replace_ratio)

            out, att_weights = self.forward(phonemes, s, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)

            loss = l1 + l_att
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
            self.optimizer.step()
            self.step += 1

            t_l1 += l1.item()
            t_att += l_att.item()

            self.logger.add_scalar('batch/total', loss.item(), self.step)

        # report average cost per batch; enumerate is 0-based, so i + 1 batches
        num_batches = i + 1
        self.logger.add_scalar('train/l1', t_l1 / num_batches, self.epoch)
        self.logger.add_scalar('train/guided_att', t_att / num_batches,
                               self.epoch)
        return t_l1 / num_batches, t_att / num_batches

    def _validate(self, dataloader):
        self.eval()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(dataloader):
            spectrs, slen, phonemes, plen, text = batch
            # generate sequentially
            out, att_weights = self.generate(phonemes,
                                             plen,
                                             steps=spectrs.shape[1],
                                             window=None)

            # generate in supervised fashion - for visualisation only
            with torch.no_grad():
                out_s, att_s = self.forward(phonemes, spectrs, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)
            t_l1 += l1.item()
            t_att += l_att.item()

            fig = display_spectr_alignment(
                out[-1, :slen[-1]], att_weights[-1][:slen[-1], :plen[-1]],
                out_s[-1, :slen[-1]], att_s[-1][:slen[-1], :plen[-1]],
                text[-1])
            self.logger.add_figure(text[-1], fig, self.epoch)

            if not self.epoch % 10:
                spec = self.collate.norm.inverse(
                    out[-1:])  # TODO: this fails if we do not standardize!
                sound, length = self.collate.stft.spec2wav(
                    spec.transpose(1, 2), slen[-1:])
                sound = sound[0, :length[0]]
                self.logger.add_audio(text[-1],
                                      sound.detach().cpu().numpy(),
                                      self.epoch,
                                      sample_rate=22050)  # TODO: parameterize

        # report average cost per batch; enumerate is 0-based, so i + 1 batches
        num_batches = i + 1
        self.logger.add_scalar('valid/l1', t_l1 / num_batches, self.epoch)
        self.logger.add_scalar('valid/guided_att', t_att / num_batches,
                               self.epoch)
        return t_l1 / num_batches, t_att / num_batches

    def train_dataloader(self, batch_size):
        return DataLoader(AudioDataset(HPText.dataset,
                                       start_idx=0,
                                       end_idx=HPText.num_train,
                                       durations=False),
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=True)

    def val_dataloader(self, batch_size):
        dataset = AudioDataset(HPText.dataset,
                               start_idx=HPText.num_train,
                               end_idx=HPText.num_valid,
                               durations=False)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False,
                          sampler=SequentialSampler(dataset))
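
# Usage sketch (illustrative; the argument values below are assumptions, and
# the `hp`/`HPText` modules referenced above must be configured for this to run):
if __name__ == '__main__':
    extractor = DurationExtractor(adam_lr=0.002, warmup_epochs=30,
                                  guided_att_sigma=0.3, device='cuda')
    extractor.fit(batch_size=32, logdir='logs/duration_extractor',
                  epochs=300, grad_clip=1, checkpoint_every=10)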
Example #7
class Network(object):
    def __init__(self,
                 model_fn,
                 flags,
                 train_loader,
                 test_loader,
                 ckpt_dir=os.path.join(os.path.abspath(''), 'models'),
                 inference_mode=False,
                 saved_model=None):
        self.model_fn = model_fn  # The model maker function
        self.flags = flags  # The Flags containing the specs
        if inference_mode:  # If inference mode, use saved model
            self.ckpt_dir = os.path.join(ckpt_dir, saved_model)
            self.saved_model = saved_model
            print("This is inference mode, the ckpt is", self.ckpt_dir)
        else:  # training mode, create a new ckpt folder
            if flags.model_name is None:  # use a custom name when provided, else a timestamp
                self.ckpt_dir = os.path.join(
                    ckpt_dir, time.strftime('%Y%m%d_%H%M%S', time.localtime()))
            else:
                self.ckpt_dir = os.path.join(ckpt_dir, flags.model_name)
        self.model = self.create_model()  # The model itself
        self.loss = self.make_loss()  # The loss function
        self.optm = None  # The optimizer: Initialized at train()
        self.lr_scheduler = None  # The lr scheduler: Initialized at train()
        self.train_loader = train_loader  # The train data loader
        self.test_loader = test_loader  # The test data loader
        self.log = SummaryWriter(
            self.ckpt_dir
        )  # Create a summary writer for keeping the summary to the tensor board
        self.best_validation_loss = float('inf')  # Initialize the best validation loss

    def create_model(self):
        """
        Function to create the network module from provided model fn and flags
        :return: the created nn module
        """
        model = self.model_fn(self.flags)
        #summary(model, input_size=(128, 8))
        print(model)
        return model

    def make_loss(self, logit=None, labels=None):
        """
        Create a tensor that represents the loss. This is consistent at both training time
        and inference time for the backward model.
        :param logit: The output of the network
        :param labels: The ground-truth spectra
        :return: the total loss
        """
        if logit is None:
            return None
        MSE_loss = nn.functional.mse_loss(logit, labels)  # MSE between prediction and truth
        BDY_loss = 0  # Boundary loss; implemented later in the backward model
        return MSE_loss + BDY_loss

    def make_optimizer(self):
        """
        Make the corresponding optimizer from the flags. Only below optimizers are allowed. Welcome to add more
        :return:
        """
        if self.flags.optim == 'Adam':
            op = torch.optim.Adam(self.model.parameters(),
                                  lr=self.flags.lr,
                                  weight_decay=self.flags.reg_scale)
        elif self.flags.optim == 'RMSprop':
            op = torch.optim.RMSprop(self.model.parameters(),
                                     lr=self.flags.lr,
                                     weight_decay=self.flags.reg_scale)
        elif self.flags.optim == 'SGD':
            op = torch.optim.SGD(self.model.parameters(),
                                 lr=self.flags.lr,
                                 weight_decay=self.flags.reg_scale)
        else:
            raise Exception(
                "Optimizer must be one of 'Adam', 'RMSprop', or 'SGD'; please update the flags"
            )
        return op

    def make_lr_scheduler(self):
        """
        Make the learning rate scheduler as instructed. More modes can be added; currently supported:
        1. ReduceLROnPlateau (decrease lr when the validation error stops improving)
        :return:
        """
        return lr_scheduler.ReduceLROnPlateau(optimizer=self.optm,
                                              mode='min',
                                              factor=self.flags.lr_decay_rate,
                                              patience=10,
                                              verbose=True,
                                              threshold=1e-4)

    def save(self):
        """
        Saving the model to the current check point folder with name best_model_forward.pt
        :return: None
        """
        #torch.save(self.model.state_dict, os.path.join(self.ckpt_dir, 'best_model_state_dict.pt'))
        torch.save(self.model,
                   os.path.join(self.ckpt_dir, 'best_model_forward.pt'))

    def load(self):
        """
        Loading the model from the check point folder with name best_model_forward.pt
        :return:
        """
        #self.model.load_state_dict(torch.load(os.path.join(self.ckpt_dir, 'best_model_state_dict.pt')))
        self.model = torch.load(
            os.path.join(self.ckpt_dir, 'best_model_forward.pt'))

    def train(self):
        """
        The major training function. Starts training using the information given in the flags.
        :return: None
        """
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()

        # Construct optimizer after the model moved to GPU
        self.optm = self.make_optimizer()
        self.lr_scheduler = self.make_lr_scheduler()

        for epoch in range(self.flags.train_step):
            # print("This is training Epoch {}".format(epoch))
            # Set to Training Mode
            train_loss = []
            train_loss_eval_mode_list = []
            self.model.train()
            for j, (geometry, spectra) in enumerate(self.train_loader):
                if cuda:
                    geometry = geometry.cuda()  # Put data onto GPU
                    spectra = spectra.cuda()  # Put data onto GPU
                self.optm.zero_grad()  # Zero the gradient first
                logit = self.model(geometry)  # Get the output
                # print("logit type:", logit.dtype)
                # print("spectra type:", spectra.dtype)
                loss = self.make_loss(logit, spectra)  # Get the loss tensor
                loss.backward()  # Calculate the backward gradients
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), 10)
                self.optm.step()  # Move one step the optimizer
                train_loss.append(np.copy(
                    loss.cpu().data.numpy()))  # Aggregate the loss

                #############################################
                # Extra test for err_test < err_train issue #
                #############################################
                self.model.eval()
                logit = self.model(geometry)  # Get the output
                loss = self.make_loss(logit, spectra)  # Get the loss tensor
                train_loss_eval_mode_list.append(
                    np.copy(loss.cpu().data.numpy()))
                self.model.train()

            # Calculate the avg loss of training
            train_avg_loss = np.mean(train_loss)
            train_avg_eval_mode_loss = np.mean(train_loss_eval_mode_list)

            if epoch % self.flags.eval_step == 0:  # For eval steps, do the evaluations and tensor board
                # Record the training loss to the tensorboard
                #train_avg_loss = train_loss.data.numpy() / (j+1)
                self.log.add_scalar('Loss/train', train_avg_loss, epoch)
                self.log.add_scalar('Loss/train_eval_mode',
                                    train_avg_eval_mode_loss, epoch)
                if self.flags.use_lorentz:
                    for j in range(self.flags.num_plot_compare):
                        f = self.compare_spectra(
                            Ypred=logit[j, :].cpu().data.numpy(),
                            Ytruth=spectra[j, :].cpu().data.numpy(),
                            E2=self.model.e2[j, :, :],
                            E1=self.model.e1[j, :, :],
                            eps_inf=self.model.eps_inf[j])
                        self.log.add_figure(tag='E1&E2{}'.format(j),
                                            figure=f,
                                            global_step=epoch)
                        f = self.compare_spectra(
                            Ypred=logit[j, :].cpu().data.numpy(),
                            Ytruth=spectra[j, :].cpu().data.numpy(),
                            N=self.model.N[j, :],
                            K=self.model.K[j, :])
                        self.log.add_figure(tag='N&K{}'.format(j),
                                            figure=f,
                                            global_step=epoch)
                        f = self.compare_spectra(
                            Ypred=logit[j, :].cpu().data.numpy(),
                            Ytruth=spectra[j, :].cpu().data.numpy(),
                            T=self.model.T_each_lor[j, :],
                            eps_inf=self.model.eps_inf[j])
                        self.log.add_figure(tag='T{}'.format(j),
                                            figure=f,
                                            global_step=epoch)
                    # For debugging purposes, the model's forward function records these tensors
                    self.log.add_histogram("w0_histogram", self.model.w0s,
                                           epoch)
                    self.log.add_histogram("wp_histogram", self.model.wps,
                                           epoch)
                    self.log.add_histogram("g_histogram", self.model.gs, epoch)

                # Set to Evaluation Mode
                self.model.eval()
                print("Doing Evaluation on the model now")
                test_loss = []
                for j, (geometry, spectra) in enumerate(
                        self.test_loader):  # Loop through the eval set
                    if cuda:
                        geometry = geometry.cuda()
                        spectra = spectra.cuda()
                    logit = self.model(geometry)
                    loss = self.make_loss(logit, spectra)  # compute the loss
                    test_loss.append(np.copy(
                        loss.cpu().data.numpy()))  # Aggregate the loss

                # Record the testing loss to the tensorboard
                test_avg_loss = np.mean(test_loss)
                self.log.add_scalar('Loss/test', test_avg_loss, epoch)

                print("This is Epoch %d, training loss %.5f, validation loss %.5f" \
                      % (epoch, train_avg_loss, test_avg_loss ))

                # Model improving, save the model down
                if test_avg_loss < self.best_validation_loss:
                    self.best_validation_loss = test_avg_loss
                    self.save()
                    print("Saving the model down...")

                    if self.best_validation_loss < self.flags.stop_threshold:
                        print("Training finished EARLIER at epoch %d, reaching loss of %.5f" %\
                              (epoch, self.best_validation_loss))
                        return None

            # Learning rate decay upon plateau
            self.lr_scheduler.step(train_avg_loss)
        self.log.close()

    def evaluate(self, save_dir='data/'):
        self.load()  # load the model as constructed
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()
        self.model.eval()  # Evaluation mode

        # Get the file names
        Ypred_file = os.path.join(save_dir,
                                  'test_Ypred_{}.csv'.format(self.saved_model))
        Xtruth_file = os.path.join(
            save_dir, 'test_Xtruth_{}.csv'.format(self.saved_model))
        Ytruth_file = os.path.join(
            save_dir, 'test_Ytruth_{}.csv'.format(self.saved_model))
        # Xpred_file = os.path.join(save_dir, 'test_Xpred_{}.csv'.format(self.saved_model))  # For pure forward model, there is no Xpred

        # Open those files to append
        with open(Xtruth_file,
                  'a') as fxt, open(Ytruth_file,
                                    'a') as fyt, open(Ypred_file, 'a') as fyp:
            # Loop through the eval data and evaluate
            for ind, (geometry, spectra) in enumerate(self.test_loader):
                if cuda:
                    geometry = geometry.cuda()
                    spectra = spectra.cuda()
                logits = self.model(geometry)
                np.savetxt(fxt, geometry.cpu().data.numpy(), fmt='%.3f')
                np.savetxt(fyt, spectra.cpu().data.numpy(), fmt='%.3f')
                np.savetxt(fyp, logits.cpu().data.numpy(), fmt='%.3f')
        return Ypred_file, Ytruth_file

    def compare_spectra(self,
                        Ypred,
                        Ytruth,
                        T=None,
                        title=None,
                        figsize=[15, 5],
                        T_num=10,
                        E1=None,
                        E2=None,
                        N=None,
                        K=None,
                        eps_inf=None):
        """
        Function to plot the comparison for predicted spectra and truth spectra
        :param Ypred:  Predicted spectra, this should be a list of number of dimension 300, numpy
        :param Ytruth:  Truth spectra, this should be a list of number of dimension 300, numpy
        :param title: The title of the plot, usually it comes with the time
        :param figsize: The figure size of the plot
        :return: The identifier of the figure
        """
        # Convert bin index into real frequency in THz
        fre_low = 0.8
        fre_high = 1.5
        frequency = fre_low + (fre_high -
                               fre_low) / len(Ytruth) * np.arange(len(Ytruth))
        f = plt.figure(figsize=figsize)
        plt.plot(frequency, Ypred, label='Pred')
        plt.plot(frequency, Ytruth, label='Truth')
        if T is not None:
            plt.plot(frequency, T, linewidth=1, linestyle='--')
        if E2 is not None:
            for i in range(np.shape(E2)[0]):
                plt.plot(frequency,
                         E2[i, :],
                         linewidth=1,
                         linestyle=':',
                         label="E2" + str(i))
        if E1 is not None:
            for i in range(np.shape(E1)[0]):
                plt.plot(frequency,
                         E1[i, :],
                         linewidth=1,
                         linestyle='-',
                         label="E1" + str(i))
        if N is not None:
            plt.plot(frequency, N, linewidth=1, linestyle=':', label="N")
        if K is not None:
            plt.plot(frequency, K, linewidth=1, linestyle='-', label="K")
        if eps_inf is not None:
            plt.plot(frequency,
                     np.ones(np.shape(frequency)) * eps_inf,
                     label="eps_inf")
        # plt.ylim([0, 1])
        plt.legend()
        #plt.xlim([fre_low, fre_high])
        plt.xlabel("Frequency (THz)")
        plt.ylabel("Transmittance")
        if title is not None:
            plt.title(title)
        return f
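
# Usage sketch (illustrative; `model_fn`, `flags`, and the data loaders are
# assumed to come from the surrounding project):
#
#     net = Network(model_fn, flags, train_loader, test_loader)
#     net.train()                             # trains with early stopping
#     ypred_csv, ytruth_csv = net.evaluate()  # writes test-set CSVs
Example #8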
def estimate(X_train, y_train):
    i = 0
    ii = 0
    nrows = 256
    ncolumns = 256
    channels = 1
    ntrain = 0.8 * len(X_train)
    nval = 0.2 * len(X_train)
    batch_size = 16
    epochs = 2
    num_cpu = multiprocessing.cpu_count()
    # Number of classes
    num_classes = 2
    torch.manual_seed(8)
    torch.cuda.manual_seed(8)
    np.random.seed(8)
    random.seed(8)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X = []
    X_train = np.reshape(np.array(X_train), [
        len(X_train),
    ])
    for img in list(range(0, len(X_train))):
        if X_train[img].ndim >= 3:
            X.append(
                np.moveaxis(
                    cv2.resize(X_train[img][:, :, :3], (nrows, ncolumns),
                               interpolation=cv2.INTER_CUBIC), -1, 0))
        else:
            smimg = cv2.cvtColor(X_train[img], cv2.COLOR_GRAY2RGB)
            X.append(
                np.moveaxis(
                    cv2.resize(smimg, (nrows, ncolumns),
                               interpolation=cv2.INTER_CUBIC), -1, 0))

        if y_train[img] == 'COVID':
            y_train[img] = 1
        elif y_train[img] == 'NonCOVID':
            y_train[img] = 0
        else:
            continue

    x = np.array(X)
    y_train = np.array(y_train)

    outputs_all = []
    labels_all = []

    X_train, X_val, y_train, y_val = train_test_split(x,
                                                      y_train,
                                                      test_size=0.2,
                                                      random_state=2)

    image_transforms = {
        'train':
        transforms.Compose([
            transforms.Lambda(lambda x: x / 255),
            transforms.ToPILImage(),
            transforms.Resize((230, 230)),
            transforms.RandomResizedCrop((224), scale=(0.75, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            #transforms.Affine(10,shear =(0.1,0.1)),
            # random brightness and random contrast
            #transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.45271412, 0.45271412, 0.45271412],
                                 [0.33165374, 0.33165374, 0.33165374])
        ]),
        'valid':
        transforms.Compose([
            transforms.Lambda(lambda x: x / 255),
            transforms.ToPILImage(),
            transforms.Resize((230, 230)),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.45271412, 0.45271412, 0.45271412],
                                 [0.33165374, 0.33165374, 0.33165374])
        ])
    }

    train_data = MyDataset(X_train, y_train, image_transforms['train'])

    valid_data = MyDataset(X_val, y_val, image_transforms['valid'])

    dataset_sizes = {'train': len(train_data), 'valid': len(valid_data)}

    dataloaders = {
        'train':
        data.DataLoader(train_data,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_cpu,
                        pin_memory=True,
                        # Seed each worker; passing np.random.seed(7) directly
                        # would run the seeding once and hand DataLoader None.
                        worker_init_fn=lambda worker_id: np.random.seed(7 + worker_id),
                        drop_last=False),
        'valid':
        data.DataLoader(valid_data,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_cpu,
                        pin_memory=True,
                        worker_init_fn=lambda worker_id: np.random.seed(7 + worker_id),
                        drop_last=False)
    }

    model = DenseNet121(num_classes, pretrained=True)

    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()
    #print(model)
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(model.parameters(), lr=0.06775, momentum=0.5518,weight_decay=0.000578)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.05)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    #scheduler = lr_scheduler.StepLR(optimizer, step_size=35, gamma=0.1)

    best_acc = -1
    best_f1 = 0.0
    best_epoch = 0
    best_loss = 100000
    since = time.time()
    writer = SummaryWriter()

    model.train()

    for epoch in range(epochs):
        print('epoch', epoch)
        jj = 0
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            predictions = FloatTensor()
            all_labels = FloatTensor()

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                predictions = predictions.to(device, non_blocking=True)
                all_labels = all_labels.to(device, non_blocking=True)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    predictions = torch.cat([predictions, preds.float()])
                    all_labels = torch.cat([all_labels, labels.float()])

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    if phase == 'train':
                        jj += 1

                        if len(inputs) >= 16:

                            writer.add_figure(
                                'predictions vs. actuals epoch ' + str(epoch) +
                                ' ' + str(jj),
                                plot_classes_preds(model, inputs, labels))

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_f1 = f1_score(all_labels.tolist(),
                                predictions.tolist(),
                                average='weighted')

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = accuracy_score(all_labels.tolist(),
                                       predictions.tolist())

            if phase == 'train':
                writer.add_scalar('Train/Loss', epoch_loss, epoch)
                writer.add_scalar('Train/Accuracy', epoch_acc, epoch)

                writer.flush()
            elif phase == 'valid':
                writer.add_scalar('Valid/Loss', epoch_loss, epoch)
                writer.add_scalar('Valid/Accuracy', epoch_acc, epoch)
                writer.flush()

            # Deep-copy the model whenever validation accuracy improves
            if phase == 'valid' and epoch_acc > best_acc:
                print('Validation accuracy improved; saving best weights')
                best_f1 = epoch_f1
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.module.state_dict())
                best_model_wts_module = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts_module)
    torch.save(model, "Model_densenet121.pth")
    torch.save(best_model_wts, "Model_densenet121_state.pth")
    time_elapsed = time.time() - since

    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best test Acc: {:4f}'.format(best_acc))
    print('Best test f1: {:4f}'.format(best_f1))
    print('best epoch: ', best_epoch)

    ## Replacing the last fully connected layer with SVM or ExtraTrees Classifiers
    model.module.fc = nn.Identity()

    for param in model.parameters():
        param.requires_grad_(False)

    clf = svm.SVC(kernel='rbf', probability=True)
    all_best_accs = {}
    all_best_f1s = {}
    #clf = ExtraTreesClassifier(n_estimators=40, max_depth=None, min_samples_split=30, random_state=0)

    for phase in ['train', 'valid']:
        outputs_all = []
        labels_all = []
        model.eval()  # Set model to evaluate mode

        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(inputs)
            outputs_all.append(outputs)
            labels_all.append(labels)

        outputs = torch.cat(outputs_all)
        labels = torch.cat(labels_all)

        # Fit the classifier on the training features, then score both splits.
        if phase == 'train':
            clf.fit(outputs.cpu(), labels.cpu())
            joblib.dump(clf, 'classifier_model.sav')

        # Predict once per split and reuse the predictions for all metrics.
        predictions = clf.predict(outputs.cpu())
        all_best_accs[phase] = accuracy_score(labels.cpu(), predictions)
        all_best_f1s[phase] = f1_score(labels.cpu(), predictions)
        print(phase, ' ', all_best_accs[phase])

    print('Best Acc: ', all_best_accs)
    print('Best f1: ', all_best_f1s)

    return model
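
# Usage sketch (illustrative): `X_train` holds the raw images (arrays) and
# `y_train` the matching 'COVID'/'NonCOVID' string labels.
#
#     trained_model = estimate(X_train, y_train)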
Example #9
File: apt.py  Project: yyht/lfi
class APT:
    """
    Implementation of
    'Automatic Posterior Transformation for Likelihood-free Inference'
    Greenberg et al.
    ICML 2019
    https://arxiv.org/abs/1905.07488
    """

    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        neural_posterior,
        num_atoms=-1,
        use_combined_loss=False,
        train_with_mcmc=False,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
        summary_writer=None,
    ):
        """
        :param simulator:
            Python object with 'simulate' method which takes a torch.Tensor
            of parameter values, and returns a simulation result for each parameter
            as a torch.Tensor.
        :param prior: Distribution
            Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor [observation_dim] or [1, observation_dim]
            True observation x0 for which to perform inference on the posterior p(theta | x0).
        :param neural_posterior: nets.Module
            Conditional density estimator q(theta | x) with 'log_prob' and 'sample' methods.
        :param num_atoms: int
            Number of atoms to use for classification.
            If -1, use all other parameters in minibatch.
        :param use_combined_loss: bool
            Whether to jointly train prior samples using maximum likelihood.
            Useful to prevent density leaking when using box uniform priors.
        :param train_with_mcmc: bool
            Whether to sample using MCMC instead of i.i.d. sampling at the end of each round
        :param mcmc_method: str
            MCMC method to use if 'train_with_mcmc' is True.
            One of ['slice-np', 'slice', 'hmc', 'nuts'].
        :param summary_net: nets.Module
            Optional network which may be used to produce feature vectors
            f(x) for high-dimensional observations.
        :param retrain_from_scratch_each_round: bool
            Whether to retrain the conditional density estimator for the posterior
            from scratch each round.
        :param discard_prior_samples: bool
            Whether to discard prior samples from round two onwards.
        :param summary_writer: SummaryWriter
            Optionally pass summary writer.
            If None, will create one internally.
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_posterior = neural_posterior

        assert isinstance(num_atoms, int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        self._use_combined_loss = use_combined_loss

        # We may want to summarize high-dimensional observations.
        # This may be either a fixed or learned transformation.
        if summary_net is None:
            self._summary_net = nn.Identity()
        else:
            self._summary_net = summary_net

        self._mcmc_method = mcmc_method
        self._train_with_mcmc = train_with_mcmc

        # HMC and NUTS from Pyro.
        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            neural_posterior, prior, self._true_observation
        )

        # Axis-aligned slice sampling implementation in NumPy
        def target_log_prob(parameters):
            if np.isinf(self._prior.log_prob(
                    torch.Tensor(parameters)).sum().item()):
                return -np.inf
            return self._neural_posterior.log_prob(
                inputs=torch.Tensor(parameters).reshape(1, -1),
                context=self._true_observation.reshape(1, -1),
            ).item()
        self._neural_posterior.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1,))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._neural_posterior.train()

        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        # If we're retraining from scratch each round,
        # keep a copy of the original untrained model for reinitialization.
        self._untrained_neural_posterior = deepcopy(neural_posterior)

        self._discard_prior_samples = discard_prior_samples

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank, self._prior_masks = [], [], []

        self._model_bank = []

        self._total_num_generated_examples = 0

        # Each APT run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(
                utils.get_log_root(), "apt", simulator.name, utils.get_timestamp()
            )
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
            "rejection-sampling-acceptance-rates": [],
        }

    def run_inference(self, num_rounds, num_simulations_per_round):
        """
        This runs APT for num_rounds rounds, using num_simulations_per_round calls to
        the simulator per round.

        :param num_rounds: Number of rounds to run.
        :param num_simulations_per_round: Number of simulator calls per round.
        :return: None
        """

        round_description = ""
        tbar = tqdm(range(num_rounds))
        for round_ in tbar:

            tbar.set_description(round_description)

            # Generate parameters from prior in first round, and from most recent posterior
            # estimate in subsequent rounds.
            if round_ == 0:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self._prior.sample(
                        (num_samples,)
                    ),
                    num_samples=num_simulations_per_round,
                )
            else:
                parameters, observations = simulators.simulation_wrapper(
                    simulator=self._simulator,
                    parameter_sample_fn=lambda num_samples: self.sample_posterior_mcmc(
                        num_samples
                    )
                    if self._train_with_mcmc
                    else self.sample_posterior(num_samples),
                    num_samples=num_simulations_per_round,
                )

            # Store (parameter, observation) pairs.
            self._parameter_bank.append(torch.Tensor(parameters))
            self._observation_bank.append(torch.Tensor(observations))
            self._prior_masks.append(
                torch.ones(num_simulations_per_round, 1)
                if round_ == 0
                else torch.zeros(num_simulations_per_round, 1)
            )
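            # Round 0 pairs are flagged as prior samples; later rounds may treat
            # them differently (e.g. when 'discard_prior_samples' is set).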

            # Fit posterior using newly aggregated data set.
            self._fit_posterior(round_=round_)

            # Store models at end of each round.
            self._model_bank.append(deepcopy(self._neural_posterior))
            self._model_bank[-1].eval()

            # Update description for progress bar.
            round_description = (
                f"-------------------------\n"
                f"||||| ROUND {round_ + 1} STATS |||||:\n"
                f"-------------------------\n"
                f"Epochs trained: {self._summary['epochs'][-1]}\n"
                f"Best validation performance: {self._summary['best-validation-log-probs'][-1]:.4f}\n\n"
            )

            # Update tensorboard and summary dict.
            self._summarize(round_)

    def sample_posterior(self, num_samples, true_observation=None):
        """
        Samples from the posterior q(theta | x0) for the true observation, using the
        most recent posterior estimate.

        :param num_samples: int
            Number of samples to generate.
        :param true_observation: torch.Tensor [observation_dim] or [1, observation_dim]
            Optionally pass true observation for inference.
            Otherwise uses true observation given at instantiation.
        :return: torch.Tensor [num_samples, parameter_dim]
            Posterior parameter samples.
        """

        true_observation = (
            true_observation if true_observation is not None else self._true_observation
        )

        # Always sample in eval mode.
        self._neural_posterior.eval()

        # Rejection sampling is potentially needed for the posterior.
        # This is because the prior may not have support everywhere.
        # The posterior may also be constrained to the same support,
        # but we don't know this a priori.
        samples = []
        num_remaining_samples = num_samples
        total_num_accepted, self._total_num_generated_examples = 0, 0
        while num_remaining_samples > 0:

            # Generate samples from posterior.
            candidate_samples = self._neural_posterior.sample(
                max(10000, num_samples), context=true_observation.reshape(1, -1)
            ).squeeze(0)

            # Evaluate posterior samples under the prior.
            prior_log_prob = self._prior.log_prob(candidate_samples)
            if isinstance(self._prior, distributions.Uniform):
                prior_log_prob = prior_log_prob.sum(-1)

            # Keep those samples which have non-zero probability under the prior.
            accepted_samples = candidate_samples[~torch.isinf(prior_log_prob)]
            samples.append(accepted_samples.detach())

            # Update remaining number of samples needed.
            num_accepted = (~torch.isinf(prior_log_prob)).sum().item()
            num_remaining_samples -= num_accepted
            total_num_accepted += num_accepted

            # Keep track of acceptance rate
            self._total_num_generated_examples += candidate_samples.shape[0]

        # Back to training mode.
        self._neural_posterior.train()

        # Aggregate collected samples.
        samples = torch.cat(samples)

        # Make sure we have the right amount.
        samples = samples[:num_samples, ...]
        assert samples.shape[0] == num_samples

        return samples

    def sample_posterior_mcmc(self, num_samples, thin=10):
        """
        Samples from the posterior q(theta | x0) for the true observation using MCMC.

        :param num_samples: Number of samples to generate.
        :param thin: Generate (num_samples * thin) samples in total, then keep every
            'thin'-th sample. The MCMC method is read from self._mcmc_method, one of
            ['slice-np', 'slice', 'hmc', 'nuts'].
        :return: torch.Tensor of shape [num_samples, parameter_dim]
        """

        # Always sample in eval mode.
        self._neural_posterior.eval()

        if self._mcmc_method == "slice-np":
            self.posterior_sampler.gen(20)  # Burn-in: discard the first 20 samples
            samples = torch.Tensor(self.posterior_sampler.gen(num_samples))

        else:
            if self._mcmc_method == "slice":
                kernel = Slice(potential_function=self._potential_function)
            elif self._mcmc_method == "hmc":
                kernel = HMC(potential_fn=self._potential_function)
            elif self._mcmc_method == "nuts":
                kernel = NUTS(potential_fn=self._potential_function)
            else:
                raise ValueError(
                    "'mcmc_method' must be one of ['slice-np', 'slice', 'hmc', 'nuts']."
                )
            num_chains = mp.cpu_count() - 1

            initial_params = self._prior.sample((num_chains,))
            sampler = MCMC(
                kernel=kernel,
                num_samples=(thin * num_samples) // num_chains + num_chains,
                warmup_steps=200,
                initial_params={"": initial_params},
                num_chains=num_chains,
                mp_context="spawn",
            )
            sampler.run()
            samples = next(iter(sampler.get_samples().values())).reshape(
                -1, self._simulator.parameter_dim
            )

            samples = samples[::thin][:num_samples]
            assert samples.shape[0] == num_samples

        # Back to training mode.
        self._neural_posterior.train()

        return samples

    def _fit_posterior(
        self,
        round_,
        batch_size=100,
        learning_rate=5e-4,
        validation_fraction=0.1,
        stop_after_epochs=20,
        clip_grad_norm=True,
    ):
        """
        Trains the conditional density estimator for the posterior by maximizing the
        log probability of the proposal posterior on the most recently aggregated
        bank of (parameter, observation) pairs.
        Uses early stopping on a held-out validation set as a terminating condition.

        :param round_: int
            Which round we're currently in. Needed when sampling procedure is
            not simply sampling from (proposal) marginal.
        :param batch_size: int
            Size of batch to use for training.
        :param learning_rate: float
            Learning rate for Adam optimizer.
        :param validation_fraction: float in [0, 1]
            The fraction of data to use for validation.
        :param stop_after_epochs: int
            The number of epochs to wait for improvement on the
            validation set before terminating training.
        :param clip_grad_norm: bool
            Whether to clip norm of gradients or not.
        :return: None
        """

        if self._discard_prior_samples and (round_ > 0):
            ix = 1
        else:
            ix = 0

        # Get total number of training examples.
        num_examples = torch.cat(self._parameter_bank[ix:]).shape[0]

        # Select random train and validation splits from (parameter, observation) pairs.
        permuted_indices = torch.randperm(num_examples)
        num_training_examples = int((1 - validation_fraction) * num_examples)
        num_validation_examples = num_examples - num_training_examples
        train_indices, val_indices = (
            permuted_indices[:num_training_examples],
            permuted_indices[num_training_examples:],
        )

        # Dataset is shared for training and validation loaders.
        dataset = data.TensorDataset(
            torch.cat(self._parameter_bank[ix:]),
            torch.cat(self._observation_bank[ix:]),
            torch.cat(self._prior_masks[ix:]),
        )

        # Create train and validation loaders using a subset sampler.
        train_loader = data.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=True,
            sampler=SubsetRandomSampler(train_indices),
        )
        val_loader = data.DataLoader(
            dataset,
            batch_size=min(batch_size, num_examples - num_training_examples),
            shuffle=False,
            drop_last=False,
            sampler=SubsetRandomSampler(val_indices),
        )

        optimizer = optim.Adam(
            list(self._neural_posterior.parameters())
            + list(self._summary_net.parameters()),
            lr=learning_rate,
        )
        # Keep track of best_validation log_prob seen so far.
        best_validation_log_prob = -1e100
        # Keep track of number of epochs since last improvement.
        epochs_since_last_improvement = 0
        # Keep track of model with best validation performance.
        best_model_state_dict = None

        # If we're retraining from scratch each round, reset the neural posterior
        # to the untrained copy we made at the start.
        if self._retrain_from_scratch_each_round and round_ > 0:
            # self._neural_posterior = deepcopy(self._untrained_neural_posterior)
            self._neural_posterior = deepcopy(self._model_bank[0])

        def _get_log_prob_proposal_posterior(inputs, context, masks):
            """
            We have two main options when evaluating the proposal posterior.
            (1) Generate atoms from the proposal prior.
            (2) Generate atoms from a more targeted distribution,
            such as the most recent posterior.
            If we choose the latter, it is likely beneficial not to do this in the first
            round, since we would be sampling from a randomly initialized neural density
            estimator.

            :param inputs: torch.Tensor Batch of parameters.
            :param context: torch.Tensor Batch of observations.
            :param masks: torch.Tensor Batch of prior masks (1 for prior-round
                samples), used only when self._use_combined_loss is set.
            :return: torch.Tensor [1] log_prob_proposal_posterior
            """

            log_prob_posterior_non_atomic = self._neural_posterior.log_prob(
                inputs, context
            )

            # just do maximum likelihood in the first round
            if round_ == 0:
                return log_prob_posterior_non_atomic

            num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

            # Each set of parameter atoms is evaluated using the same observation,
            # so we repeat rows of the context.
            # e.g. [1, 2] -> [1, 1, 2, 2]
            repeated_context = utils.repeat_rows(context, num_atoms)

            # To generate the full set of atoms for a given item in the batch,
            # we sample without replacement num_atoms - 1 times from the rest
            # of the parameters in the batch.
            assert 0 < num_atoms - 1 < batch_size
            probs = (
                (1 / (batch_size - 1))
                * torch.ones(batch_size, batch_size)
                * (1 - torch.eye(batch_size))
            )
            choices = torch.multinomial(
                probs, num_samples=num_atoms - 1, replacement=False
            )
            contrasting_inputs = inputs[choices]

            # We can now create our sets of atoms from the contrasting parameter sets
            # we have generated.
            atomic_inputs = torch.cat(
                (inputs[:, None, :], contrasting_inputs), dim=1
            ).reshape(batch_size * num_atoms, -1)

            # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
            log_prob_posterior = self._neural_posterior.log_prob(
                atomic_inputs, repeated_context
            )
            assert utils.notinfnotnan(
                log_prob_posterior
            ), "NaN/inf detected in posterior eval."
            log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)

            # Get (batch_size * num_atoms) log prob prior evals.
            if isinstance(self._prior, distributions.Uniform):
                log_prob_prior = self._prior.log_prob(atomic_inputs).sum(-1)
                # log_prob_prior = torch.zeros(log_prob_prior.shape)
            else:
                log_prob_prior = self._prior.log_prob(atomic_inputs)
            log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
            assert utils.notinfnotnan(log_prob_prior), "NaN/inf detected in prior eval."

            # Compute unnormalized proposal posterior.
            unnormalized_log_prob_proposal_posterior = (
                log_prob_posterior - log_prob_prior
            )

            # Normalize proposal posterior across discrete set of atoms.
            log_prob_proposal_posterior = unnormalized_log_prob_proposal_posterior[
                :, 0
            ] - torch.logsumexp(unnormalized_log_prob_proposal_posterior, dim=-1)
            assert utils.notinfnotnan(
                log_prob_proposal_posterior
            ), "NaN/inf detected in proposal posterior eval."

            if self._use_combined_loss:
                masks = masks.reshape(-1)

                log_prob_proposal_posterior = (
                    masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
                )

            return log_prob_proposal_posterior

        epochs = 0
        while True:

            # Train for a single epoch.
            self._neural_posterior.train()
            for batch in train_loader:
                optimizer.zero_grad()
                inputs, context, masks = (
                    batch[0].to(device),
                    batch[1].to(device),
                    batch[2].to(device),
                )
                summarized_context = self._summary_net(context)
                log_prob_proposal_posterior = _get_log_prob_proposal_posterior(
                    inputs, summarized_context, masks
                )
                loss = -torch.mean(log_prob_proposal_posterior)
                loss.backward()
                if clip_grad_norm:
                    clip_grad_norm_(self._neural_posterior.parameters(), max_norm=5.0)
                optimizer.step()

            epochs += 1

            # Calculate validation performance.
            self._neural_posterior.eval()
            log_prob_sum = 0
            with torch.no_grad():
                for batch in val_loader:
                    inputs, context, masks = (
                        batch[0].to(device),
                        batch[1].to(device),
                        batch[2].to(device),
                    )
                    log_prob = _get_log_prob_proposal_posterior(inputs, context, masks)
                    log_prob_sum += log_prob.sum().item()
            validation_log_prob = log_prob_sum / num_validation_examples

            # Check for improvement in validation performance over previous epochs.
            if validation_log_prob > best_validation_log_prob:
                best_validation_log_prob = validation_log_prob
                epochs_since_last_improvement = 0
                best_model_state_dict = deepcopy(self._neural_posterior.state_dict())
            else:
                epochs_since_last_improvement += 1

            # If no validation improvement over many epochs, stop training.
            if epochs_since_last_improvement > stop_after_epochs - 1:
                self._neural_posterior.load_state_dict(best_model_state_dict)
                break

        # Update summary.
        self._summary["epochs"].append(epochs)
        self._summary["best-validation-log-probs"].append(best_validation_log_prob)

    def _estimate_acceptance_rate(self, num_samples=int(1e7), true_observation=None):
        """
        Estimates rejection sampling acceptance rates.

        :param num_samples:
            Number of samples to use.
        :param true_observation:
            Observation on which to condition.
            If None, use true observation given at initialization.
        :return: float in [0, 1]
            Fraction of accepted samples.
        """
        true_observation = (
            true_observation if true_observation is not None else self._true_observation
        )

        # Always sample in eval mode.
        self._neural_posterior.eval()

        total_num_accepted_samples, total_num_generated_samples = 0, 0
        while total_num_generated_samples < num_samples:

            # Generate samples from posterior.
            candidate_samples = self._neural_posterior.sample(
                10000, context=true_observation.reshape(1, -1)
            ).squeeze(0)

            # Evaluate posterior samples under the prior.
            prior_log_prob = self._prior.log_prob(candidate_samples)
            if isinstance(self._prior, distributions.Uniform):
                prior_log_prob = prior_log_prob.sum(-1)

            # Update remaining number of samples needed.
            num_accepted_samples = (~torch.isinf(prior_log_prob)).sum().item()
            total_num_accepted_samples += num_accepted_samples

            # Keep track of acceptance rate
            total_num_generated_samples += candidate_samples.shape[0]

        # Back to training mode.
        self._neural_posterior.train()

        return total_num_accepted_samples / total_num_generated_samples

    @property
    def summary(self):
        return self._summary

    def _summarize(self, round_):

        # Update summaries. MMD can only be computed for simulators that expose
        # ground-truth posterior samples, so failures are silently skipped.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(num_samples=1000),
            )
            self._summary["mmds"].append(mmd.item())
        except Exception:
            pass

        # Median |x - x0| for most recent round.
        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._observation_bank[-1] - self._true_observation.reshape(1, -1))
                    ** 2,
                    dim=-1,
                )
            )
        )
        self._summary["median-observation-distances"].append(
            median_observation_distance.item()
        )

        # KDE estimate of negative log prob true parameters using
        # parameters from most recent round.
        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item()
        )

        # Rejection sampling acceptance rate
        rejection_sampling_acceptance_rate = self._estimate_acceptance_rate()
        self._summary["rejection-sampling-acceptance-rates"].append(
            rejection_sampling_acceptance_rate
        )

        # Plot most recently sampled parameters.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()
            ).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )

        # Write quantities using SummaryWriter.
        self._summary_writer.add_figure(
            tag="posterior-samples", figure=figure, global_step=round_ + 1
        )

        self._summary_writer.add_scalar(
            tag="epochs-trained",
            scalar_value=self._summary["epochs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="best-validation-log-prob",
            scalar_value=self._summary["best-validation-log-probs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="median-observation-distance",
            scalar_value=self._summary["median-observation-distances"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="negative-log-prob-true-parameters",
            scalar_value=self._summary["negative-log-probs-true-parameters"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="rejection-sampling-acceptance-rate",
            scalar_value=self._summary["rejection-sampling-acceptance-rates"][-1],
            global_step=round_ + 1,
        )

        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
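
A hedged usage sketch for the class above. Its name and full constructor signature
do not appear in this excerpt, so the class name `APT` and the keyword arguments
below are assumptions inferred from the attributes the methods read
(self._simulator, self._prior, self._num_atoms, ...):

apt = APT(
    simulator=simulator,
    prior=prior,
    true_observation=true_observation,
    neural_posterior=neural_posterior,
    num_atoms=10,
)
apt.run_inference(num_rounds=10, num_simulations_per_round=1000)
posterior_samples = apt.sample_posterior(num_samples=5000)
print(apt.summary["best-validation-log-probs"])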
Example #10
class Trainer:
    def __init__(self,
                 model,
                 model_name,
                 train_dataloader,
                 val_dataloader=None,
                 mode='predict_same',
                 device=None,
                 version=1,
                 log_dir='./logs',
                 learning_rate=1e-3,
                 custom_time=None):
        self.model_name = model_name
        self.version = version

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        self.mode = mode

        self.device = device
        if self.device is None:
            self.device = torch.device('cpu')

        self.custom_time = custom_time
        if self.custom_time is not None:
            self.custom_time = torch.Tensor(custom_time).to(self.device)

        self.model = model.to(self.device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.logger = SummaryWriter(log_dir=log_dir)

    def process_epoch(self, dataloader, train=True):
        losses, mses, maes = [], [], []
        self.model.train(train)
        for batch in dataloader:
            x = batch[0].to(self.device)
            t = batch[1].to(self.device)
            mask = batch[2].to(self.device)

            x_to_predict, t_to_predict, mask_to_predict = None, None, None
            if len(batch) == 6:
                x_to_predict = batch[3].to(self.device)
                t_to_predict = batch[4].to(self.device)
                mask_to_predict = batch[5].to(self.device)

            loss, mse, mae = self.model.compute_loss(x,
                                                     t,
                                                     mask,
                                                     x_to_predict,
                                                     t_to_predict,
                                                     mask_to_predict,
                                                     return_metrics=True)
            if train:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            losses.append(float(loss.cpu().detach()))
            mses.append(float(mse.cpu().detach()))
            maes.append(float(mae.cpu().detach()))

        return np.mean(losses), np.mean(mses), np.mean(maes)

    def eval_single_batch(self,
                          dataloader,
                          custom_time=None,
                          expand_batch=False):
        x, t, mask, pred, t_pred, mask_pred = (None, ) * 6

        self.model.eval()
        for batch in dataloader:
            x = batch[0].to(self.device)
            t = batch[1].to(self.device)
            mask = batch[2].to(self.device)

            if expand_batch:
                custom_time = batch[4].to(self.device)

            with torch.no_grad():
                pred = self.model.sample_similar(x,
                                                 t,
                                                 mask,
                                                 custom_time=custom_time)
            break

        if expand_batch:
            x = torch.cat([batch[0], batch[3]], dim=1)
            t = torch.cat([batch[1], batch[4]], dim=0)
            mask = torch.cat([batch[2], batch[5]], dim=1)
            mask_pred = batch[5]

        t_pred = custom_time if custom_time is not None else t

        return x.cpu(), t.cpu(), mask.cpu(), pred.cpu(), t_pred.cpu(), mask_pred

    def draw_plot(self,
                  epoch,
                  x_real,
                  t_real,
                  x_pred,
                  t_pred=None,
                  mask_real=None,
                  mask_pred=None,
                  mode_real='scatter',
                  mode_pred='plot',
                  figure_name='Val/img',
                  vline_shift=None):
        if t_pred is None:
            t_pred = t_real
        if mask_real is None:
            mask_real = init_mask(x_real)
        if mask_pred is None:
            mask_pred = init_mask(x_pred)

        def add_plot(X, t, mask, mode, marker='o', ls='-'):
            for x, m, color in zip(X, mask, ['r', 'g', 'b']):
                if mode == 'scatter':
                    plt.scatter(t[m].cpu().detach().numpy().ravel(),
                                x[m].cpu().detach().numpy().ravel(),
                                color=color,
                                marker=marker)
                elif mode == 'plot':
                    plt.plot(t[m].cpu().detach().numpy().ravel(),
                             x[m].cpu().detach().numpy().ravel(),
                             color=color,
                             ls=ls)
                else:
                    raise ValueError('Unknown plot mode {}'.format(mode))

        figure = plt.figure(figsize=(8, 8))

        add_plot(x_real, t_real, mask_real, mode_real, marker='o', ls='-')
        add_plot(x_pred, t_pred, mask_pred, mode_pred, marker='x', ls='--')

        if vline_shift is not None:
            plt.vlines(vline_shift,
                       min(x_real.min(), x_pred.min()).cpu() - 1,
                       max(x_real.max(), x_pred.max()).cpu() + 1)

        plt.grid()
        plt.ylim((min(x_real.min(), x_pred.min()).cpu() - 1,
                  max(x_real.max(), x_pred.max()).cpu() + 1))
        self.logger.add_figure(figure_name, figure, global_step=epoch)
        plt.close()

    def log_metrics(self, loss, mse, mae, epoch, type='Train'):
        if loss is not None:
            self.logger.add_scalar(type + '/Loss', np.mean(loss), epoch)
        if mse is not None:
            self.logger.add_scalar(type + '/MSE', np.mean(mse), epoch)
        if mae is not None:
            self.logger.add_scalar(type + '/MAE', np.mean(mae), epoch)

    def train(self, epoch_num, save=True):
        for epoch in tqdm(range(epoch_num)):
            loss, mse, mae = self.process_epoch(self.train_dataloader,
                                                train=True)
            self.log_metrics(loss, mse, mae, epoch, type='Train')

            if (epoch + 1) % 20 == 0:
                x, t, mask, pred, t_pred, mask_pred = self.eval_single_batch(
                    dataloader=self.train_dataloader, expand_batch=True)
                self.draw_plot(epoch,
                               x,
                               t,
                               pred,
                               t_pred=t_pred,
                               mask_real=mask,
                               mask_pred=mask_pred,
                               mode_real='scatter',
                               mode_pred='scatter',
                               figure_name='Train/img',
                               vline_shift=5.0)

            if self.val_dataloader is not None:
                loss, mse, mae = self.process_epoch(self.val_dataloader,
                                                    train=False)
                self.log_metrics(loss, mse, mae, epoch, type='Val')

                if (epoch + 1) % 20 == 0:
                    pred = self.eval_single_batch(
                        dataloader=self.val_dataloader,
                        custom_time=self.custom_time)[3]
                    x = self.val_dataloader.get_functions_values(
                        [0, 1, 2], self.custom_time.cpu())
                    self.draw_plot(epoch,
                                   x,
                                   self.custom_time,
                                   pred,
                                   mode_real='plot',
                                   mode_pred='plot',
                                   figure_name='Val/img',
                                   vline_shift=[5.0, 10.0])

        if save:
            self.save_model()

    def save_model(self):
        torch.save(self.model.state_dict(),
                   self.model_name + '_v{}.pt'.format(self.version))

    def load_model(self):
        self.model.load_state_dict(
            torch.load(self.model_name + '_v{}.pt'.format(self.version)))
        self.model.eval()
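
A brief usage sketch for this Trainer. The model is assumed to implement
compute_loss and sample_similar with the signatures called above, and the
dataloaders are assumed to yield (x, t, mask) or six-element batches:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = Trainer(model,
                  model_name='seq_model',
                  train_dataloader=train_dataloader,
                  val_dataloader=val_dataloader,
                  device=device,
                  version=1,
                  log_dir='./logs/seq_model',
                  learning_rate=1e-3)
trainer.train(epoch_num=100)  # logs metrics each epoch, figures every 20 epochs
trainer.load_model()          # reloads the checkpoint saved as seq_model_v1.pt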
Example #11
                'stat_value': best_stat,
                "d_model": D,
                "d_loss": d_loss,
                "d_optimizer": D_optimizer,
                "g_model": G,
                "g_loss": g_loss,
                "g_optimizer": G_optimizer,
            }, os.path.join(args.out_dir, f"models/{args.model_name}_best_for_mmd_checkpoint.pkl"))
        tb_writer.add_scalar("MMD", np.mean(stat_list), global_step=epoch)
        if epoch % 50 == 0 or epoch == args.n_epochs:
            print(f'Epoch-{epoch}; D_loss: {d_loss.data.cpu().numpy()}; G_loss: {g_loss.data.cpu().numpy()}')
            torch.save({
                'epoch': epoch,
                'stat_value': np.mean(stat_list),
                "d_model": D,
                "d_loss": d_loss,
                "d_optimizer": D_optimizer,
                "g_model": G,
                "g_loss": g_loss,
                "g_optimizer": G_optimizer,
            }, os.path.join(args.out_dir, f"models/{args.model_name}_epoch_{epoch}_checkpoint.pkl"))
            with torch.no_grad():
                noise = torch.rand(args.batch_size, args.seq_len, 1, device=args.device)
                fake_seq = G(noise, real_labels)
                _seq = fake_seq[0].cpu().numpy()  # batch_first :^)
#                 _label = _labels[0].cpu().numpy()
                fig = save_ecg_example(_seq, f"pictures/{args.model_name}_epoch_{epoch}_example")
                tb_writer.add_figure("generated_example", fig, global_step=epoch)

                # TODO use visualize func here
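
save_ecg_example is not defined in this excerpt. A plausible minimal version
(an assumption, not the project's actual helper) that saves the plot to disk
and returns the figure so it can be passed to add_figure:

import matplotlib.pyplot as plt

def save_ecg_example(seq, path):
    # seq: a 1-D (or [seq_len, 1]) numpy array holding one generated ECG trace.
    fig = plt.figure(figsize=(10, 3))
    plt.plot(seq.squeeze())
    plt.xlabel('sample')
    plt.ylabel('amplitude')
    fig.savefig(path + '.png', bbox_inches='tight')
    return fig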
Example #12
class TensorboardLogger(StrategyLogger):
    """
    The `TensorboardLogger` provides an easy integration with
    Tensorboard logging. Each monitored metric is automatically
    logged to Tensorboard.
    The user can inspect results in real time by appropriately launching
    tensorboard with `tensorboard --logdir=/path/to/tb_log_exp_name`.

    AWS S3 buckets and (if tensorflow is installed) GCloud storage URLs are
    supported as log directories.

    If no parameters are provided, the default folder in which tensorboard
    log files are placed is "./tb_data".

    .. note::
        We rely on the PyTorch implementation of Tensorboard. If you
        don't have Tensorflow installed in your environment,
        tensorboard will tell you that it is running with a reduced
        feature set. This should not impact the logger's performance.
    """

    def __init__(self, tb_log_dir: Union[str, Path] = "./tb_data",
                 filename_suffix: str = ''):
        """
        Creates an instance of the `TensorboardLogger`.

        :param tb_log_dir: path to the directory where the tensorboard log file
            will be stored. Defaults to "./tb_data".
        :param filename_suffix: string suffix to append at the end of the
            tensorboard log file name. Defaults to ''.
        """

        super().__init__()
        tb_log_dir = _make_path_if_local(tb_log_dir)
        self.writer = SummaryWriter(tb_log_dir,
                                    filename_suffix=filename_suffix)

    def __del__(self):
        self.writer.close()

    def log_metric(self, metric_value: MetricValue, callback: str):
        super().log_metric(metric_value, callback)
        name = metric_value.name
        value = metric_value.value

        if isinstance(value, AlternativeValues):
            value = value.best_supported_value(Image, Tensor, TensorImage,
                                               Figure, float, int)

        if isinstance(value, Figure):
            self.writer.add_figure(name, value,
                                   global_step=metric_value.x_plot)

        elif isinstance(value, Image):
            self.writer.add_image(name, to_tensor(value),
                                  global_step=metric_value.x_plot)

        elif isinstance(value, Tensor):
            self.writer.add_histogram(name, value,
                                      global_step=metric_value.x_plot)

        elif isinstance(value, (float, int)):
            self.writer.add_scalar(name, value,
                                   global_step=metric_value.x_plot)

        elif isinstance(value, TensorImage):
            self.writer.add_image(name, value.image,
                                  global_step=metric_value.x_plot)
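
A short usage sketch: the logger is constructed once and then driven by the
surrounding framework, which builds the MetricValue objects whose name, value,
and x_plot fields log_metric reads above (how those objects are produced is
outside this excerpt):

logger = TensorboardLogger(tb_log_dir='./tb_data', filename_suffix='_exp1')
# For a plain scalar metric, the dispatch above reduces to:
#   logger.writer.add_scalar(metric_value.name, metric_value.value,
#                            global_step=metric_value.x_plot)
# Results can then be inspected with: tensorboard --logdir=./tb_data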
Example #13
class ModelBase(object):
    """The base model interface and the training logic."""
    def __init__(self,
                 num_classes,
                 num_features,
                 opt,
                 device='cpu',
                 visualizer=None):
        """
        The constructor

        :param num_classes: the total number of gesture classes
        :param num_features: the dimensionality of each feature vector
        :param opt: an instance of Options
        :param device: the compute device to use
        :param visualizer: an optional visualizer to plot the training progress
        """
        self._num_classes = num_classes
        self._num_features = num_features
        self._opt = opt
        self._visualizer = visualizer
        self._device = device

        self._random = Random(self._opt.seed)
        # The data loader instance (used for the training loop)
        self._data_loader = None
        # The data normalizer (scaler) used to transform the data from the range [-1, 1]
        # back to the original feature scale.
        self._normalizer = None

        # The generator network and the optimizer
        self._generator = None
        self._optimizer = None

        # The name of the metrics to keep track of
        self.metric_names = []

        # Random noise for latent space representation
        self._latent = torch.FloatTensor(self._opt.batch_size,
                                         self._opt.resample_n,
                                         self._opt.latent_dim).to(self._device)

        self._stats = {}
        self._best_model = None
        self._best_model_metric = np.inf
        self._best_model_which_metric = None  # Which key in self._stats to use when saving the best model

        # Create tensorboard writer
        if opt.use_tensorboard > 0:
            # Lazy-load tensorboard stuff
            from torch.utils.tensorboard import SummaryWriter
            self._tb_writer = SummaryWriter(self._opt.run_tb_dir,
                                            filename_suffix='.tfevents')
            print(
                F"Tensorboard logs will be dumped in '{self._tb_writer.get_logdir()}'"
            )
        else:
            self._tb_writer = None

    #
    # Public properties and functions
    #

    @property
    def device(self):
        """Returns the device that this network is on."""
        return self._device

    @property
    def normalizer(self):
        """Returns the normalizer of this model."""
        return self._normalizer

    def generate(self, labels, latent_vector=None, unnormalize=False):
        """
        Generates a batch of fake samples of the given labels

        :param labels: the labels to generate the data for
        :param latent_vector: the latent vector to use for generation. If nothing is provided, a
            new latent vector is generated.
        :param unnormalize: whether the returned samples should be "unnormalized", i.e.
            transformed back to the original feature space scale.
        :return: a batch of generated data
        """
        labels_one_hot = self._to_one_hot(labels)
        curr_batch_size = labels_one_hot.shape[0]

        # Generate or reuse the latent vector
        if latent_vector is None:
            self._generate_new_latent(curr_batch_size)
            latent_vector = self._latent
        elif latent_vector.shape[0] != curr_batch_size:  # Sanity check
            raise ValueError(
                "The batch size of the provided latent vector does not match that of the labels vector."
            )

        result = self._generator(latent_vector, labels_one_hot)

        if unnormalize:
            result = self._normalizer.unnormalize_list(result)

        return result

    def run_training_loop(self, data_split):
        """
        Runs the training loop on the given data loader

        :param data_split: the split of data to train on
        """
        self._data_loader = data_split.get_data_loader()

        if self._normalizer is None:
            self._normalizer = data_split.normalizer

        print('Beginning training')

        for epoch in range(self._opt.epochs):
            self.begin_epoch()
            self._run_one_epoch(epoch)
            self.end_epoch()

            # Do some logging
            self._log_to_tensorboard('train', epoch)
            stats = self.get_stats()
            print(
                F"Epoch {epoch} \t\t(took {stats['deltatime']})\t"
                F"{self._best_model_which_metric}={stats[self._best_model_which_metric]}"
            )

            # Save a checkpoint (if enabled)
            if self._opt.checkpoint_frequency > 0 and \
               epoch > 0 and epoch % self._opt.checkpoint_frequency == 0:
                self.save(save_best=False, suffix=str(epoch))

        print("Training finished!")
        # Save the current state into a checkpoint
        self.save(save_best=False, suffix=str(self._opt.epochs))

    def bookkeep(self):
        """
        Bookkeep the stats.
        """
        self._stats['cnt'] += 1

        # Process all metrics
        for metric_name in self.metric_names:
            if hasattr(self, metric_name):
                attr = getattr(self, metric_name)

                if metric_name not in self._stats:
                    self._stats[metric_name] = []

                self._stats[metric_name].append(
                    attr.item() if attr is not None else 0)
            else:
                raise Exception(
                    F"The metric with the name {metric_name} was not found")

    def begin_epoch(self):
        """
        Mark the start of a training epoch.
        """
        self._stats = {'cnt': 0, 'start_time': time.time()}

        # Reset all metrics
        for metric_name in self.metric_names:
            if hasattr(self, metric_name):
                setattr(self, metric_name, None)

    def end_epoch(self):
        """
        Mark the end of a training epoch.
        """
        self._stats['stop_time'] = time.time()
        best_metric_candidate = self.get_stats()[self._best_model_which_metric]

        if best_metric_candidate < self._best_model_metric:
            self._best_model = self._get_state()
            self._best_model_metric = best_metric_candidate

    def get_stats(self):
        """
        :return: the stats collected during training
        """
        stats = copy.deepcopy(self._stats)

        for key in stats.keys():
            if key == 'cnt' or 'time' in key:
                continue

            stats[key] = np.mean(stats[key])

        # Calculate delta epoch time
        stats['deltatime'] = str(
            timedelta(seconds=stats['stop_time'] - stats['start_time']))
        del stats['start_time']
        del stats['stop_time']

        return stats

    def stat_str(self):
        """
        :return: the string representation of the collected training stats.
        """
        stat = self.get_stats()
        result = F"{{'best_{self._best_model_which_metric}': {self._best_model_metric}}} {str(stat)}"
        return result

    def save(self, save_best=True, suffix=None):
        """
        Saves the model into a checkpoint file.

        :param save_best: if `True`, will save the model yielding the best metric.
        :param suffix: an optional suffix to use for the saved filename.
        """
        # Decide on the model and the filename
        if save_best:
            model = self._best_model
            fname = 'checkpoint-best.tar'
        else:
            model = self._get_state()
            fname = F'checkpoint-{suffix}.tar' if suffix is not None else 'checkpoint.tar'

        path = os.path.join(self._opt.run_checkpoint_dir, fname)
        print(
            F"Saving the {'best ' if save_best else ''}checkpoint in '{path}'")
        torch.save(model, path)

    def load(self, path):
        """
        Loads the model stored in a checkpoint file.

        :param path: the path of the checkpoint file.
        """
        print(F"Loading the checkpoint from '{path}'...")
        loaded = torch.load(path)
        self._load_state(loaded)

    #
    # Private functions
    #

    def _to_device(self):
        """
        Transfer everything to the correct computation device
        """
        raise NotImplementedError()

    def _to_one_hot(self, labels):
        """
        Converts a list of labels to their one-hot representation with correct format for label-conditioning

        :param labels: input labels
        :return: the one-hot representation of the input labels
        """
        one_hot_converted = one_hot(labels,
                                    self._num_classes).to(self._device,
                                                          dtype=torch.float32,
                                                          non_blocking=True)
        return one_hot_converted.unsqueeze(1).expand(-1, self._opt.resample_n,
                                                     -1)

    def _generate_new_latent(self, batch_size):
        """
        Generates a new latent vector and stores internally.

        :param batch_size: the batch size of the generated latent
        """
        self._latent.resize_(batch_size, self._opt.resample_n,
                             self._opt.latent_dim).normal_(0, 1)
        # Some PyTorch versions have a bug that results in the generation of NaN values after the above
        # operation. Here, we just replace those NaN values with zeros.
        self._latent[torch.isnan(self._latent)] = 0

    def _log_to_tensorboard(self, suffix, epoch):
        """
        Logs the current training progress to tensorboard (if enabled).

        :param suffix: the suffix to use for logging
        :param epoch:  the epoch counter
        """
        if not self._opt.use_tensorboard:
            return

        stat = self.get_stats()

        # Log the overall best metric
        self._tb_writer.add_scalar(
            F'best_{self._best_model_which_metric}_{suffix}',
            self._best_model_metric, epoch)
        # Log the metric that we track and save as the best metric
        self._tb_writer.add_scalar(F'{self._best_model_which_metric}_{suffix}',
                                   stat[self._best_model_which_metric], epoch)

        for metric_name in self.metric_names:
            if metric_name in stat:
                self._tb_writer.add_scalar(F'{metric_name}_{suffix}',
                                           stat[metric_name], epoch)

        # Should we visualize?
        if self._opt.vis_frequency > 0 and \
                epoch % self._opt.vis_frequency == 0 and \
                self._visualizer is not None and \
                self._num_features == 2:  # Only visualize if the features are 2D
            # Visualize the constant latent
            self._visualizer.visualize(self, self._data_loader)
            self._tb_writer.add_figure('const_latent', self._visualizer.fig,
                                       epoch)

    def _run_one_epoch(self, epoch):
        """
        Runs a single epoch (or step) of training.

        :param epoch: the epoch counter
        """
        raise NotImplementedError()

    def _get_state(self):
        """
        Helper function to get the internal state of the model.

        :return: a dictionary containing the internal state of the model.
        """
        raise NotImplementedError()

    def _load_state(self, state_dict):
        """
        Helper function to load the internal state of a model from a dictionary.

        :param state_dict: the dictionary containing the state to load.
        """
        raise NotImplementedError()
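
A minimal sketch of a concrete subclass wiring up the abstract hooks above. The
toy generator and the reconstruction-style loss are placeholders for
illustration only, and `opt` is assumed to carry the fields the base class
reads (seed, batch_size, resample_n, latent_dim, epochs, use_tensorboard, ...):

class _ToyGenerator(torch.nn.Module):
    """Maps (latent, one-hot label) pairs to feature vectors."""
    def __init__(self, latent_dim, num_classes, num_features):
        super().__init__()
        self.fc = torch.nn.Linear(latent_dim + num_classes, num_features)

    def forward(self, latent, labels_one_hot):
        return self.fc(torch.cat([latent, labels_one_hot], dim=-1))


class ToyModel(ModelBase):
    def __init__(self, num_classes, num_features, opt, device='cpu', visualizer=None):
        super().__init__(num_classes, num_features, opt, device, visualizer)
        self._generator = _ToyGenerator(opt.latent_dim, num_classes,
                                        num_features).to(device)
        self._optimizer = torch.optim.Adam(self._generator.parameters())
        self.metric_names = ['g_loss']
        self.g_loss = None
        self._best_model_which_metric = 'g_loss'

    def _run_one_epoch(self, epoch):
        # Batches are assumed to be (features, labels) pairs with features of
        # shape [batch, resample_n, num_features].
        for reals, labels in self._data_loader:
            fakes = self.generate(labels.to(self._device))
            # Placeholder loss; a real subclass would use its GAN objective.
            self.g_loss = ((fakes - reals.to(self._device)) ** 2).mean()
            self._optimizer.zero_grad()
            self.g_loss.backward()
            self._optimizer.step()
            self.bookkeep()

    def _get_state(self):
        return {'generator': self._generator.state_dict()}

    def _load_state(self, state_dict):
        self._generator.load_state_dict(state_dict['generator'])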
Example #14
            # generate scatterplots and add to tensorboard
            fig = plt.figure(figsize=(12, 8))
            plt.scatter(df_tmp['age_at_scan'],
                        df_tmp['predicted_age'],
                        alpha=1)
            tstr = '%s - MAE: %.2f - rho: %.3f (%s)' % (phase, mae,
                                                        correlation, key)
            plt.title(tstr, fontsize=16)
            plt.plot(lims, lims, 'k:')
            plt.grid()
            plt.xlim(lims)
            plt.ylim(lims)
            plt.xlabel('Chronological age', fontsize=16)
            plt.ylabel('Predicted age', fontsize=16)
            writer.add_figure('predictions/' + phase + '_' + key, fig, epoch)
            plt.close()
            fname = os.path.join(results_dir,
                                 'predictions_' + phase + '_' + key + '.csv')
            ratings[phase][key].save_df(fname)

        # calculate ensemble predictions
        df_tmp['predicted_age'] = np.mean(predictions, axis=0)
        fname = os.path.join(results_dir,
                             'predictions_' + phase + '_ensemble.csv')
        df_tmp.to_csv(fname, index=False)
        maes[phase + '_ensemble'] = mean_absolute_error(
            df_tmp['age_at_scan'], df_tmp['predicted_age'])

    writer.add_scalars('mae', maes, epoch)
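
The final writer.add_scalars call groups several scalars under one parent tag
so the per-phase MAE curves share a single chart; a minimal standalone
illustration of that call:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(3):
    # The dict maps one name per phase to its value, like `maes` above.
    writer.add_scalars('mae', {'train_ensemble': 3.2 - 0.1 * epoch,
                               'val_ensemble': 3.8 - 0.05 * epoch}, epoch)
writer.close()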
Example #15
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 1000 == 999:  # every 1000 mini-batches
            # ...log the running loss
            writer.add_scalar('train_loss', running_loss / 1000,
                              epoch * len(trainloader) + i)

            # ...log a Matplotlib Figure showing the model's predictions on a
            # random mini-batch
            writer.add_figure('predictions vs. actual',
                              plot_classes_preds(net, inputs, labels),
                              global_step=epoch * len(trainloader) + i)
            running_loss = 0.0
print('Finished Training')
'''Assess trained models with tensorboard'''
# 1. gets the probability predictions in a test_size x num_classes tensor
# 2. gets the preds in a test_size tensor
# takes ~10 seconds to run
class_probs = []
class_preds = []
with torch.no_grad():
    for data in testloader:
        images, labels = data
        output = net(images)
        class_probs_batch = [F.softmax(el, dim=0) for el in output]
        _, class_preds_batch = torch.max(output, 1)
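        # The excerpt stops mid-loop. In the standard PyTorch TensorBoard
        # tutorial this snippet follows, each batch's results are appended and
        # per-class precision-recall curves are logged afterwards; the
        # continuation below is a sketch along those lines, assuming `classes`
        # holds the class names.
        class_probs.append(class_probs_batch)
        class_preds.append(class_preds_batch)

test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
test_preds = torch.cat(class_preds)

for class_index in range(len(classes)):
    # add_pr_curve(tag, labels, predictions): binary class membership against
    # the predicted probability for that class.
    writer.add_pr_curve(classes[class_index],
                        test_preds == class_index,
                        test_probs[:, class_index],
                        global_step=0)
writer.close()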
Example #16
        # gradient descent or adam step
        optimizer.step()

        # visualizing Dataset images
        # img_grid = torchvision.utils.make_grid(images)
        # writer.add_image('Xray_images', img_grid)

        # calculate running accuracy
        model.eval()
        _, predictions = scores.max(1)
        num_correct = (predictions == labels).sum()
        batch_correct_pred += float(num_correct)

        writer.add_figure('predictions vs. actuals',
                           plot_classes_preds(model, images, labels),
                           global_step=step)      #epoch * len(train_loader) + batch_idx
        step += 1

    print(batch_correct_pred)
    epoch_elapsed = (time.time() - epoch_start_time) / 60
    print(f'Epoch {epoch} completed in : {epoch_elapsed:.2f} min')

    batch_loss = sum(losses)/len(losses)
    batch_accuracy = (batch_correct_pred/total_batch_images)*100

    print(f"Cost at epoch {epoch} is {batch_loss}")
    print(f"Training accuracy at {epoch} is: {batch_accuracy:.2f}")
    # batch_accuracy = check_accuracy(train_loader, model)

    if batch_accuracy > best_acc:
Example #17
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))
    print(save_dir)

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # default: BaseAugmentation
    dataset = dataset_module(data_dir=data_dir)
    num_classes = dataset.num_classes  # 18

    # -- augmentation
    transform_module = getattr(import_module("dataset"),
                               args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=8,
        shuffle=True,
        pin_memory=use_cuda,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        batch_size=args.valid_batch_size,
        num_workers=8,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=True,
    )

    # -- model
    model_module = getattr(import_module("model"),
                           args.model)  # default: BaseModel
    model = torch.nn.DataParallel(
        model_module(num_classes=num_classes).to(device))

    # -- loss & metric
    criterion = create_criterion(args.criterion)  # default: cross_entropy
    opt_module = getattr(import_module("torch.optim"),
                         args.optimizer)  # default: SGD
    optimizer = opt_module(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        # momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Using AdamP instead:
    # optimizer = AdamP(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    # optimizer = AdamP(
    #     model.parameters(),
    #     lr= 1e-3,
    #     betas=(0.9, 0.999),
    #     eps=1e-8,
    #     weight_decay=0,
    #     delta = 0.1,
    #     wd_ratio = 0.1
    # )

    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.1)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, eta_min=0.000005)  # changed since warm restarts are said to be less likely to get stuck in local minima
    # scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

    # -- logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)

            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            matches += (preds == labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss,
                                  epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc,
                                  epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        scheduler.step()

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                loss_item = criterion(outs, labels).item()
                acc_item = (labels == preds).sum().item()
                val_loss_items.append(loss_item)
                val_acc_items.append(acc_item)

                if figure is None:
                    inputs_np = torch.clone(inputs).detach().cpu().permute(
                        0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(
                        inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels, preds,
                        args.dataset != "MaskSplitByProfileDataset")

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            best_val_loss = min(best_val_loss, val_loss)
            if val_acc > best_val_acc:
                print(
                    f"New best model for val accuracy : {val_acc:4.2%}! saving the best model.."
                )
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_acc = val_acc
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            logger.add_scalar("Val/loss", val_loss, epoch)
            logger.add_scalar("Val/accuracy", val_acc, epoch)
            logger.add_figure("results", figure, epoch)
            print()
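
seed_everything is called at the top of train() but not defined in this
excerpt; a common implementation (an assumption about this project) is:

import os
import random

import numpy as np
import torch

def seed_everything(seed):
    # Seed every RNG that affects training so runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)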
Example #18
    torch.save(state, save_path)


if __name__ == "__main__":
    cfg = args.get_argparser('configs/psr_siamdiff_pauli.yml')
    del cfg.test
    torch.backends.cudnn.benchmark = True

    # generate work dir
    run_id = osp.join(
        r'runs', cfg.model.arch + '_' + cfg.train.loss.name + '_' +
        cfg.train.optimizer.name)
    run_id = utils.get_work_dir(run_id)
    writer = SummaryWriter(log_dir=run_id)
    config_fig = types.dict2fig(cfg.to_flatten_dict())
    # plt.savefig(r'./tmp/ff.png')
    writer.add_figure('config', config_fig, close=True)
    # writer.add_hparams(types.flatten_dict_summarWriter(cfg), {'a': 'b'})
    writer.flush()

    # logger
    logger = get_logger(run_id)

    # print('-'*100)
    logger.info(f'RUNDIR: {run_id}')
    logger.info(f'using config file: {cfg.config_file}')
    shutil.copy(cfg.config_file, run_id)

    train(cfg, writer, logger)
    logger.info(f'RUNDIR:{run_id}')
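
get_logger is not shown in this excerpt; a minimal version (an assumption)
that writes to both the console and a file inside the run directory:

import logging
import os.path as osp

def get_logger(run_id, name='train'):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    for handler in (logging.StreamHandler(),
                    logging.FileHandler(osp.join(run_id, 'run.log'))):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger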
Example #19
# load data
dataloader = DataLoader(NatPatchDataset(arg.batch_size, arg.size, arg.size), batch_size=250)
# train
optim = torch.optim.SGD([{'params': sparse_net.U.weight, "lr": arg.learning_rate}])
for e in range(arg.epoch):
    running_loss = 0
    c = 0
    for img_batch in tqdm(dataloader, desc='training', total=len(dataloader)):
        img_batch = img_batch.reshape(img_batch.shape[0], -1).to(device)
        # update
        pred = sparse_net(img_batch)
        loss = ((img_batch - pred) ** 2).sum()
        running_loss += loss.item()
        loss.backward()
        # update U
        optim.step()
        # zero grad
        sparse_net.zero_grad()
        # norm
        sparse_net.normalize_weights()
        c += 1
    board.add_scalar('Loss', running_loss / c, e * len(dataloader) + c)
    if e % 5 == 4:
        # plotting
        fig = plot_rf(sparse_net.U.weight.T.reshape(arg.n_neuron, arg.size, arg.size).cpu().data.numpy(), arg.n_neuron, arg.size)
        board.add_figure('RF', fig, global_step=e * len(dataloader) + c)
    if e % 10 == 9:
        # save checkpoint
        torch.save(sparse_net, f"../../trained_models/ckpt-{e+1}.pth")
torch.save(sparse_net, f"../../trained_models/ckpt-{e+1}.pth")
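
plot_rf is a helper not shown in this excerpt. A plausible minimal version (an
assumption) that tiles the learned receptive fields into a square grid figure
suitable for add_figure:

import numpy as np
import matplotlib.pyplot as plt

def plot_rf(rf, n_neuron, size):
    # rf: array of shape [n_neuron, size, size] holding the learned weights.
    cols = int(np.ceil(np.sqrt(n_neuron)))
    fig, axes = plt.subplots(cols, cols, figsize=(cols, cols))
    for i, ax in enumerate(np.ravel(axes)):
        ax.axis('off')
        if i < n_neuron:
            ax.imshow(rf[i], cmap='gray')
    return fig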
Example #20
    def train(self):
        # Load saved model if resume option selected
        if self.resume:
            print(Trainer.time_str() + ' Resuming training ... ')
            checkpoint = torch.load(os.path.join(self.log_root, self.get_epoch_root(self.resume_epoch), 'torch_model_optim.pth'))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print(Trainer.time_str() + ' Starting training ... ')

        writer = SummaryWriter(self.log_root)
        self.model = self.model.to(self.device)
        epoch = int(self.model.epoch) + 1
        batch_counter = int(self.model.iteration)
        # Epoch loop
        for epoch in range(epoch, epoch + self.num_epoch):
            # Create logging directory
            epoch_root = self.get_epoch_root(epoch) 
            if not os.path.exists(os.path.join(self.log_root, epoch_root)):
                os.makedirs(os.path.join(self.log_root, epoch_root))
            # Select data loaders for the current epoch
            cur_epoch_loaders = self.get_epoch_loaders(epoch)
            # Dictionary (of dictionaries) to collect four metrics from different phases for tensorboard
            epoch_metric_names = ['epoch_loss', 'epoch_accuracy', 'epoch_precision', 'epoch_recall']
            epoch_metric_dict = {metric_name: dict.fromkeys(cur_epoch_loaders.keys()) for metric_name in epoch_metric_names}
            # Loop over phases within one epoch [train, validation, test]
            for phase in cur_epoch_loaders.keys():
                # Select training state of the NN model
                if phase == 'train':
                    self.model.train(True)
                else:
                    self.model.train(False)

                # Select Loader
                cur_loader = cur_epoch_loaders[phase]
                # Number of samples
                columns = self.target_names.columns
                sample_count_df = pd.DataFrame(np.zeros([2, len(columns)],
                                               dtype=np.int64),
                                               columns=columns,
                                               index=('No', 'Yes'))
                num_samples = len(cur_loader.batch_sampler.sampler)
                total_sample_counter = 0
                num_target_class = self.model.classifier.num_output
                # initializing variables for keeping track of results for tensorboard reporting
                results_phase = self.init_results_phase(num_samples=num_samples, num_target_class=num_target_class)
                for i, data in enumerate(cur_loader):
                    batch_counter += 1
                    # Copy input and targets to the device object
                    inputs = data['input'].to(self.device)
                    type_indices = self.get_target_type_index()
                    targets = data['target'][:, type_indices].float().squeeze().to(self.device)
                    # Zero the parameter gradients
                    self.optimizer.zero_grad()
                    # Forward pass
                    outputs = self.model(inputs).squeeze()
                    loss = self.criterion(outputs, targets)
                    # Backward + Optimize(in training)
                    if phase == 'train':
                        loss.mean().backward()
                        self.optimizer.step()
                    
                    # Record results of the operation for the reporting
                    results_batch = self.get_results_batch(results_phase.keys(), data, loss, outputs)
                    # Aggregate results into a phase array for complete epoch reporting
                    cur_batch_size = inputs.shape[0]
                    nominal_batch_size = cur_loader.batch_size
                    results_phase, batch_idx_range = self.update_results_phase(results_batch=results_batch, 
                                                                               results_phase=results_phase,
                                                                               nominal_batch_size=nominal_batch_size,
                                                                               cur_batch_size=cur_batch_size,
                                                                               batch_idx=i)
                    # Gather number of each class in mini batch
                    total_sample_counter += cur_batch_size
                    non_zero_count = np.count_nonzero(results_batch['target'], axis=0)
                    cur_sample_count = np.vstack((cur_batch_size-non_zero_count, non_zero_count))
                    assert (cur_sample_count.sum(axis=0) == cur_batch_size).all(), 'Sum to batch size check failed'
                    sample_count_df = sample_count_df + cur_sample_count
                    # logging for the running loss and accuracy for each target class
                    if i % self.log_int == 0:
                        running_loss_log = results_phase['loss'][:batch_idx_range[1]].mean(axis=0)
                        running_accuracy = results_phase['correct'][:batch_idx_range[1]].mean(axis=0)
                        accuracy_dict = self.add_target_names(running_accuracy.round(3))
                        running_loss_dict = self.add_target_names(running_loss_log.round(3))
                        print(Trainer.time_str() + ' Phase: ' + phase +
                              f', epoch: {epoch}, batch: {i}, running loss: {running_loss_dict}, running accuracy: {accuracy_dict}')
                        writer.add_scalars(f'running_loss/{phase}', running_loss_dict, batch_counter)
                        writer.add_scalars(f'running_accuracy/{phase}', accuracy_dict, batch_counter)

                # Number of samples in epoch checked two ways
                assert total_sample_counter == num_samples
                # Make sure no -1s left in the phase results (excluding input which throws errors)
                for key in ['loss', 'output_prob', 'prediction', 'target', 'correct']:
                    assert not (results_phase[key] == -1).any()
                # Fraction for each class of target
                class_fraction_df = sample_count_df / num_samples
                assert np.isclose(class_fraction_df.sum(), 1.0).all(), 'All fraction sum to 1.0 failed'
                # the index for positive examples in each class
                with_index = 'Yes'
                fraction_positive_dict = class_fraction_df.loc[with_index].to_dict()
                writer.add_scalars(f'Fraction_with_target/{phase}', fraction_positive_dict, epoch)
                # calculate epoch loss and accuracy average over batch samples
                # Epoch error measures
                epoch_loss_log = results_phase['loss'].mean(axis=0)
                epoch_loss_dict = self.add_target_names(epoch_loss_log.round(3))
                epoch_accuracy_log = results_phase['correct'].mean(axis=0)
                epoch_acc_dict = self.add_target_names(epoch_accuracy_log.round(3))
                print(Trainer.time_str() + ' Phase: ' + phase +
                      f', epoch: {epoch}: epoch loss: {epoch_loss_dict}, epoch accuracy: {epoch_acc_dict}')
                
                # Pickle important results dict elements: loss, output_prob and dataset_indices
                dict_to_save = {key: results_phase[key] for key in ['loss', 'output_prob', 'dataset_indices']}
                io.save_dict(dict_to_save, os.path.join(self.log_root, epoch_root, 'results_saved.pkl'))

                # Precision, recall, accuracy and loss 
                precision, recall, _, num_pos = sk_metrics.precision_recall_fscore_support(results_phase['target'].squeeze(),
                                                                                           results_phase['prediction'].squeeze(),
                                                                                           zero_division=0)
               
                # The metrics function returns the result for both positive and negative labels when operated with
                # a single target type. When the task is a multilabel decision it only returns the positive label results
                if num_target_class == 1:
                    precision = [precision[1]]
                    recall = [recall[1]]
                    num_pos = num_pos[1]
                assert (np.asarray(sample_count_df.loc['Yes']) == num_pos).all(), 'Number of positive samples matching failed'
                cur_metrics = [epoch_loss_dict, epoch_acc_dict,
                               self.add_target_names(precision), self.add_target_names(recall)]
                for i, metric_name in enumerate(epoch_metric_names):
                    epoch_metric_dict[metric_name][phase] = cur_metrics[i]
                
                # Confusion matrix Figure
                if num_target_class == 1:
                    confusion_matrix = sk_metrics.confusion_matrix(results_phase['target'].squeeze(),
                                                                   results_phase['prediction'].squeeze())
                elif num_target_class > 1:
                    confusion_matrix = sk_metrics.multilabel_confusion_matrix(results_phase['target'],
                                                                              results_phase['prediction'])
                else:
                    raise Exception('number of target classes is negative')
                fig_confusion_norm = self.plot_confusion_matrix(confusion_matrix)
                figname_confusion = 'Confusion_matrix'
                fig_confusion_norm.savefig(os.path.join(self.log_root,
                                                        epoch_root,
                                                        figname_confusion + phase + '.png'),
                                           dpi=300)
                writer.add_figure(f'{figname_confusion}/{phase}', fig_confusion_norm, epoch)

                # Images with highest loss in each target type (Myelin and artefact currently)
                fig = self.show_imgs(results_phase=results_phase)
                figname_examples = 'Examples_with_highest_loss'
                fig.savefig(os.path.join(self.log_root, epoch_root, figname_examples + '_' + phase + '.png'), dpi=300)
                writer.add_figure(f'{figname_examples}/{phase}', fig, epoch)

                # Precision/Recall curves
                for i, t_type in enumerate(self.target_names):
                    writer.add_pr_curve(f'{t_type}/{phase}',
                                        labels=results_phase.get('target')[:, i],
                                        predictions=results_phase.get('output_prob')[:, i],
                                        global_step=epoch,
                                        num_thresholds=100)

                # save model
                if self.save & (phase == 'train') & (epoch % self.save_int == 0):
                    print(Trainer.time_str() + ' Writing model graph ... ')
                    # writer.add_graph(self.model, inputs)
                    print(Trainer.time_str() + ' Saving model state... ')
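                    # Store progress counters as non-trainable Parameters so
                    # they are serialized inside the model state_dict.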
                    self.model.epoch = torch.nn.Parameter(torch.tensor(epoch), requires_grad=False)
                    self.model.iteration = torch.nn.Parameter(torch.tensor(batch_counter), requires_grad=False)
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict()
                    }, os.path.join(self.log_root, epoch_root, 'torch_model_optim.pth'))

            # write the epoch related metrics to the tensorboard
            for metric_name in epoch_metric_names:
                cur_metric = epoch_metric_dict[metric_name]
                for ph in cur_metric:
                    cur_metric_phase = {f'{ph}_{t_type}': val for t_type, val in cur_metric[ph].items()}
                    writer.add_scalars(metric_name, cur_metric_phase, epoch)
        print(Trainer.time_str() + ' Finished training ... ')
        writer.close()
        print(Trainer.time_str() + ' Closed writer ... ')
Example #21
class BaseTrainer:
    def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer):
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.torch_istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config["save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be at least 1."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config["validation_interval"]
        self.save_max_metric_score = self.validation_config["save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be at least 1."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # If the 'resume' flag is True in 'train.py', the following attributes are updated when resuming:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(), max_queue=5, flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1
            )

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open((self.save_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(), f"The file {model_path.as_posix()} does not exist. Please check the path."

        map_location = {'cuda:%d' % 0: 'cuda:%d' % self.rank}
        model_checkpoint = torch.load(model_path.as_posix(), map_location=map_location)
        self.model.load_state_dict(model_checkpoint["model"], strict=False)

        if self.rank == 0:
            print(f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute() / "latest_model.tar"
        assert latest_model_path.exists(), f"{latest_model_path} does not exist, can not load latest checkpoint."

        self.dist.barrier()  # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work

        map_location = {'cuda:%d' % 0: 'cuda:%d' % self.rank}
        checkpoint = torch.load(latest_model_path.as_posix(), map_location=map_location)

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])
        self.model.load_state_dict(checkpoint["model"])

        if self.rank == 0:
            print(f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch.")

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - the epoch number
            - the best metric score in history
            - the optimizer parameters
            - the model parameters

        Args:
            is_best_epoch (bool): If the model achieved the best metric score in the current epoch
                                (is_best_epoch=True), the checkpoint will also be saved as
                                "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        # TODO
        # Unify how the "module.*" prefix is handled between training and inference...
        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict(),
            "model": self.model.state_dict()
        }

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.tar"
        # Contains all checkpoint information, like "latest_model.tar". However, newer checkpoints will not overwrite older ones.
        torch.save(state_dict, (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.tar").as_posix())

        # If the model achieved the best metric score in the current epoch (is_best_epoch=True),
        # the checkpoint is additionally saved as "best_model.tar".
        # A newer best-scoring checkpoint overwrites the older one.
        if is_best_epoch:
            print(self.color_tool.red(f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        print(f"This project contains {len(models)} models, the number of the parameters is: ")

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(f"The amount of parameters in the project is {params_of_all_networks / 1e6} million.")

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self, noisy, enhanced, clean, name, epoch, mark=""):
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy", noisy, epoch, sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean", clean, epoch, sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(self.librosa_stft(noisy, n_fft=320, hop_length=160, win_length=320))
        enhanced_mag, _ = librosa.magphase(self.librosa_stft(enhanced, n_fft=320, hop_length=160, win_length=320))
        clean_mag, _ = librosa.magphase(self.librosa_stft(clean, n_fft=320, hop_length=160, win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(
                f"mean: {np.mean(mag):.3f}, "
                f"std: {np.std(mag):.3f}, "
                f"max: {np.max(mag):.3f}, "
                f"min: {np.min(mag):.3f}"
            )
            librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self, noisy_list, clean_list, enhanced_list, metrics_list, epoch, num_workers=10, mark=""):
        """
        Compute metrics on the validation dataset in parallel.

        Notes:
            1. You can register other metrics, but the STOI and WB_PESQ metrics must be present; they
             are used to check whether the current epoch is a "best epoch."
            2. To use a new metric, register it in the "util.metrics" file.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, noisy_list)
            )
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, enhanced_list)
            )

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(self.color_tool.yellow(f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(metric_score, save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(f"[{timer.duration()} seconds] Training has finished, validation is in progress...")

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(metric_score, save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError
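
Note: `metrics_visualization` above depends on a metric registry in `util.metrics` and on `transform_pesq_range` to map WB-PESQ scores onto [0, 1] so they can be averaged with STOI. A plausible sketch of both, assuming the common `pystoi` and `pesq` packages as backends (the real module may differ):

from pesq import pesq
from pystoi import stoi

def STOI(ref, est, sr=16000):
    return stoi(ref, est, sr, extended=False)

def WB_PESQ(ref, est, sr=16000):
    return pesq(sr, ref, est, "wb")

REGISTERED_METRICS = {"STOI": STOI, "WB_PESQ": WB_PESQ}

def transform_pesq_range(pesq_score):
    # WB-PESQ lives roughly in [-0.5, 4.5]; map it onto [0, 1].
    return (pesq_score + 0.5) / 5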
Example #22
class Saver(object):
    """
    Saver allows for saving and restore networks.
    """
    def __init__(self,
                 base_output_dir: Path,
                 args: dict,
                 sub_dirs=('train', 'test'),
                 tag=''):

        # Store args
        self.args = args
        # Create experiment directory
        timestamp_str = datetime.fromtimestamp(
            time()).strftime('%Y_%m_%d_%H_%M_%S')
        if isinstance(tag, str) and len(tag) > 0:
            # Append tag
            timestamp_str += f"_{tag}"
        self.path = base_output_dir / f'{timestamp_str}'
        self.path.mkdir(parents=True, exist_ok=True)

        # Setup loggers
        self.tb = None
        self.cml = None
        if has_tb:
            self.tb = SummaryWriter(str(self.path))
        if has_cml and args.cometml_api_key_path is not None and \
                       args.cometml_workspace is not None and \
                       args.cometml_project is not None:
            # Read API key
            with open(args.cometml_api_key_path, 'r') as file:
                api_key = file.read().strip()
            # Read project and workspace
            cometml_project = args.cometml_project.strip()
            cometml_workspace = args.cometml_workspace.strip()
            # Create experiment
            self.cml = comet_ml.Experiment(api_key=api_key,
                                           project_name=cometml_project,
                                           workspace=cometml_workspace,
                                           parse_args=False)
            self.cml.set_name(tag)
        # Warnings
        if self.tb is None and self.cml is None:
            print('Saver: warning: no logger')
        else:
            if self.tb is not None:
                print('Saver: using TensorBoard')
            if self.cml is not None:
                print('Saver: using CometML')
        # Create checkpoint sub-directory
        self.ckpt_path = self.path / 'ckpt'
        self.ckpt_path.mkdir(parents=True, exist_ok=True)
        # Create output sub-directories
        self.sub_dirs = sub_dirs
        self.output_path = {}
        for s in self.sub_dirs:
            self.output_path[s] = self.path / 'output' / s
        for d in self.output_path.values():
            d.mkdir(parents=True, exist_ok=False)
        # Dump experiment hyper-params
        with open(self.path / 'hyperparams.txt', mode='wt') as f:
            args_str = [f'{a}: {v}\n' for a, v in self.args.__dict__.items()]
            args_str.append(f'exp_name: {timestamp_str}\n')
            f.writelines(sorted(args_str))
        # Dump command
        with open(self.path / 'command.txt', mode='wt') as f:
            cmd_args = ' '.join(sys.argv)
            f.write(cmd_args)
            f.write('\n')
        # Dump the `git log` and `git diff`. In this way one can checkout
        #  the last commit, add the diff and should be in the same state.
        for cmd in ['log', 'diff']:
            with open(self.path / f'git_{cmd}.txt', mode='wt') as f:
                subprocess.run(['git', cmd], stdout=f)

    def save_data(self, data, name: str):
        """
        Save generic data
        """
        torch.save(data, self.path / f'{name}.pth')

    def save_model(self, net: torch.nn.Module, name: str, epoch: int):
        """
        Save model parameters in the checkpoint directory.
        """
        # Get state dict
        state_dict = net.state_dict()
        # Copy to CPU
        for k, v in state_dict.items():
            state_dict[k] = v.cpu()
        # Save
        torch.save(state_dict, self.ckpt_path / f'{name}_{epoch:05d}.pth')

    def add_graph(self, model, images):
        if self.tb is not None:
            self.tb.add_graph(model, images)

    def dump_batch_image(self, image: torch.FloatTensor, epoch: int,
                         split: str, name: str):
        """
        Dump image batch into folder (as grid) and tb
        TODO: something's wrong with the whole BGR2RGB stuff, we shouldn't need it
        """
        assert split in self.sub_dirs
        assert len(image.shape) == 4, f'shape {image.shape} differs from BxCxHxW format'
        assert image.min() >= 0 and image.max() <= 1, 'image must be between 0 and 1!'

        out_image_path = self.output_path[split] / f'{epoch:05d}_{name}.jpg'
        image_rolled = torchvision.utils.make_grid(
            image.cpu(), nrow=8,
            pad_value=1)  #, normalize=True, scale_each=True)
        # Save image file
        TF.to_pil_image(image_rolled).save(out_image_path)
        # TensorBoard
        if self.tb is not None:
            self.tb.add_image(f'{split}/{name}', image_rolled, epoch)
        # CometML
        if self.cml is not None:
            self.cml.log_image(image_rolled,
                               name=f'{split}/{name}',
                               step=epoch,
                               image_channels='first')

    def dump_batch_video(self, video: torch.FloatTensor, epoch: int,
                         split: str, name: str):
        """
        Dump video batch into folder (as grid) and tb
        FIXME: not sure this works
        """
        assert split in self.sub_dirs
        assert len(video.shape) == 5, f'shape {video.shape} differs from BxTxCxHxW format'
        assert video.min() >= 0 and video.max() <= 1, 'video must be between 0 and 1!'
        out_image_path = self.output_path[split] / f'{epoch:05d}_{name}.jpg'
        video_rolled = video_tensor_to_grid(video, return_image=True)
        cv2.imwrite(str(out_image_path), np.transpose(video_rolled, (1, 2, 0)))
        if self.tb is not None:
            self.tb.add_video(name, video, epoch, fps=5)

    def dump_line(self, line, step, split, name=None, fmt=None, labels=None):
        """
        Dump line as matplotlib figure into folder and tb
        TODO: test CometML
        """
        # Line data
        fig = plt.figure()
        if isinstance(line, tuple):
            line_x, line_y = line
            line_x = line_x.cpu().detach()
            line_y = line_y.cpu().detach()
        else:
            line_x = torch.arange(line.numel())
            line_y = line.cpu().detach()
        # Plot (matplotlib's fmt string is positional, not a keyword argument)
        if fmt is not None:
            plt.plot(line_x, line_y, fmt)
        else:
            plt.plot(line_x, line_y)
        # Ticks
        if labels is not None:
            pass
            #plt.xticks(line_x, labels, rotation='vertical', fontsize=4)
            #plt.margins(0.9)
            #plt.subplots_adjust(bottom=0.8)
        # Save
        if name is not None:
            assert split in self.sub_dirs
            out_path = self.output_path[split] / f'line_{step:08d}_{name.replace("/", "_")}.jpg'
            plt.savefig(out_path)
        if self.tb is not None:
            self.tb.add_figure(
                f'{split}/{name}' if name is not None else split, fig, step)
        if self.cml is not None:
            self.cml.log_figure(
                figure_name=f'{split}/{name}' if name is not None else split,
                figure=fig,
                step=step)

    def dump_histogram(self, tensor: torch.Tensor, epoch: int, desc: str):
        """
        TODO: disabled for CometML, too slow
        """
        values = tensor.contiguous().view(-1)
        if self.tb is not None:
            #try:
            self.tb.add_histogram(desc, values, epoch)
            #except:
            #print('Error writing histogram')
        #if self.cml is not None:
        #    self.cml.log_histogram_3d(values, desc, epoch)

    def dump_metric(self, value: float, epoch: int, *tags):
        if self.tb is not None:
            self.tb.add_scalar('/'.join(tags), value, epoch)
        if self.cml is not None:
            self.cml.log_metric('/'.join(tags), value, step=epoch)

    @staticmethod
    def load_hyperparams(hyperparams_path):
        """
        Load hyperparams from file. Tries to convert them to best type.
        """
        # Process input
        hyperparams_path = Path(hyperparams_path)
        if not hyperparams_path.exists():
            raise OSError('Please provide a valid path')
        if hyperparams_path.is_dir():
            hyperparams_path = os.path.join(hyperparams_path,
                                            'hyperparams.txt')
        # Prepare output
        output = {}
        # Read file
        with open(hyperparams_path) as file:
            # Read lines
            for l in file:
                # Remove new line
                l = l.strip()
                # Separate name from value
                toks = l.split(':')
                name = toks[0]
                value = ':'.join(toks[1:]).strip()
                # Parse value
                try:
                    value = ast.literal_eval(value)
                except (ValueError, SyntaxError):
                    # keep the raw string if it cannot be parsed
                    pass
                # Add to output
                output[name] = value
        # Return
        return output
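
    # Usage sketch (hypothetical values): for a hyperparams.txt line like
    # "lr: 0.001", load_hyperparams returns {'lr': 0.001}, with the value
    # parsed back to a float by ast.literal_eval; values that fail to parse
    # are kept as raw strings.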

    @staticmethod
    def load_state_dict(model_path: Union[str, Path], verbose: bool = True):
        """
        Load state dict from pre-trained checkpoint. In case a directory is
          given as `model_path`, the last modified checkpoint is loaded.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise OSError('Please provide a valid path for restoring weights.')

        if model_path.is_dir():
            # Check there are files in that directory
            file_list = sorted(model_path.glob('*.pth'), key=getmtime)
            if len(file_list) == 0:
                # Check there are files in the 'ckpt' subdirectory
                model_path = model_path / 'ckpt'
                file_list = sorted(model_path.glob('*.pth'), key=getmtime)
                if len(file_list) == 0:
                    raise OSError("Couldn't find pth file.")
            checkpoint = file_list[-1]
        elif model_path.is_file():
            checkpoint = model_path

        if verbose:
            print(f'Loading pre-trained weight from {checkpoint}...')

        return torch.load(checkpoint)

    def close(self):
        if self.tb is not None:
            self.tb.close()
Example #23
class Logger:
    def __init__(self,
                 logdir,
                 rank,
                 type='torch',
                 debug=False,
                 filename=None,
                 summary=True,
                 step=None):
        self.logger = None
        self.type = type
        self.rank = rank
        self.step = step
        self.logdir_results = os.path.join("logs", "results")
        self.summary = summary
        if summary:
            if type == 'tensorboardX':
                import tensorboardX
                self.logger = tensorboardX.SummaryWriter(logdir)
            elif type == "torch":
                from torch.utils.tensorboard import SummaryWriter
                self.logger = SummaryWriter(logdir)
            else:
                raise NotImplementedError
        else:
            self.type = 'None'

        self.debug_flag = debug
        logging.basicConfig(filename=filename,
                            level=logging.INFO,
                            format=f'%(levelname)s:rank{rank}: %(message)s')

        if rank == 0:
            os.makedirs(self.logdir_results, exist_ok=True)
            logging.info(f"[!] starting logging at directory {logdir}")
            if self.debug_flag:
                logging.info(f"[!] Entering DEBUG mode")

    def close(self):
        if self.logger is not None:
            self.logger.close()
        self.info("Closing the Logger.")

    def add_scalar(self, tag, scalar_value, step=None):
        if self.is_not_none():
            tag = self._transform_tag(tag)
            self.logger.add_scalar(tag, scalar_value, step)

    def add_image(self, tag, image, step=None):
        if self.is_not_none():
            tag = self._transform_tag(tag)
            self.logger.add_image(tag, image, step)

    def add_figure(self, tag, image, step=None):
        if self.is_not_none():
            tag = self._transform_tag(tag)
            self.logger.add_figure(tag, image, step)

    def add_table(self, tag, tbl, step=None):
        if self.is_not_none():
            tag = self._transform_tag(tag)
            tbl_str = "<table width=\"100%\"> "
            tbl_str += "<tr> \
                     <th>Term</th> \
                     <th>Value</th> \
                     </tr>"

            for k, v in tbl.items():
                tbl_str += "<tr> \
                           <td>%s</td> \
                           <td>%s</td> \
                           </tr>" % (k, v)

            tbl_str += "</table>"
            self.logger.add_text(tag, tbl_str, step)

    def add_results(self, results, tag="Results"):
        if self.is_not_none():
            tag = self._transform_tag(tag)
            text = "<table width=\"100%\">"
            for k, res in results.items():
                text += f"<tr><td>{k}</td>" + " ".join(
                    [str(f'<td>{x}</td>') for x in res.values()]) + "</tr>"
            text += "</table>"
            self.logger.add_text(tag, text)

    def print(self, msg):
        logging.info(msg)

    def info(self, msg):
        if self.rank == 0:
            logging.info(msg)

    def debug(self, msg):
        if self.rank == 0 and self.debug_flag:
            logging.info(msg)

    def error(self, msg):
        logging.error(msg)

    def log_results(self, task, name, results, novel=False):
        if self.rank == 0:
            file_name = f"{task.task}-n{task.nshot}.csv" if task.nshot != -1 else f"{task.task}-0"
            file_name = file_name if not novel else f"{file_name}_novel.csv"
            dir_path = f"{self.logdir_results}/{task.dataset}"
            path = f"{dir_path}/{file_name}"
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
            text = [
                str(round(time.time())), name,
                str(self.step),
                str(task.nshot),
                str(task.ishot)
            ]
            for val in results:
                text.append(str(val))
            row = ",".join(text) + "\n"
            with open(path, "a") as file:
                file.write(row)

    def log_aggregates(self, task, name, results):
        if self.rank == 0:
            file_name = f"{task.task}-n{task.nshot}-agg.csv" if task.nshot != -1 else f"{task.task}-0-agg.csv"
            dir_path = f"{self.logdir_results}/{task.dataset}"
            path = f"{dir_path}/{file_name}"
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)
            text = [
                str(round(time.time())), name,
                str(self.step),
                str(task.nshot),
                str(task.ishot)
            ]
            for val in results:
                text.append(str(val))
            row = ",".join(text) + "\n"
            with open(path, "a") as file:
                file.write(row)

    def _transform_tag(self, tag):
        tag = tag + f"/{self.step}" if self.step is not None else tag
        return tag

    def is_not_none(self):
        return self.type != "None"
Example #24
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # configure loggers (logging will not work before this point)
        util.setup_logger('base', opt['path']['log'],  opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = tuple(int(v) for v in torch.__version__.split('.')[:2])
            if version >= (1, 1):  # PyTorch >= 1.1 ships native tensorboard support
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(torch.__version__))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
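    # The NoneDict idea can be sketched as a dict whose __missing__ hook
    # returns None instead of raising KeyError (hypothetical sketch; the
    # real implementation lives in the `option` module):
    #
    #     class NoneDict(dict):
    #         def __missing__(self, key):
    #             return None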

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            
            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            
            #### update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan'] and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        if opt['save_img']:
                            util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
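                        # util.calculate_psnr presumably follows the standard
                        # uint8 definition; a sketch of that assumption:
                        #   mse = np.mean((sr_img.astype(np.float64) -
                        #                  gt_img.astype(np.float64)) ** 2)
                        #   psnr = 10 * np.log10(255.0 ** 2 / mse)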
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                        # add_figure expects a matplotlib figure, not a raw
                        # uint8 array (assumes matplotlib.pyplot imported as plt)
                        fig = plt.figure()
                        plt.imshow(sr_img)
                        tb_logger.add_figure('output', fig, current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
                                                               device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
                        # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                                tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        tb_logger.close()
Example #25
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PRETRAINED EMOTION CLASSIFICATION')
    parser.add_argument('--batch-size',
                        type=int,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--dataset-dir',
                        default='data',
                        help='directory that contains cifar-10-batches-py/ '
                        '(downloaded automatically if necessary)')
    parser.add_argument('--epochs',
                        type=int,
                        metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--log-interval',
                        type=int,
                        default=75,
                        metavar='N',
                        help='number of batches between logging train status')
    parser.add_argument('--lr', type=float, metavar='LR', help='learning rate')
    parser.add_argument('--model-name',
                        type=str,
                        default='run-01',
                        help='saves the current model')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0,
                        help='Weight decay hyperparameter')
    parser.add_argument('--continue-train',
                        type=str,
                        default='NONE',
                        help='saves the current model')
    parser.add_argument('--examine', default=False, action='store_true')
    parser.add_argument('--visualize', default=False, action='store_true')
    args = parser.parse_args()
    # set seed for reproducibility (honor the --seed flag; it was previously
    # shadowed by a hard-coded value)
    seed = args.seed

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
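    # benchmarking trades reproducibility for speed; keep it off for determinism
    torch.backends.cudnn.benchmark = False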

    train_imgs_dir = os.path.join(args.dataset_dir, "train")
    train_labels = pd.read_csv(
        os.path.join(args.dataset_dir, "label_eliminate_2/train_labels.csv"))

    val_imgs_dir = os.path.join(args.dataset_dir, "test")
    val_labels = pd.read_csv(
        os.path.join(args.dataset_dir, "label_eliminate_2/test_labels.csv"))

    #test_imgs_dir = os.path.join(args.dataset_dir, "test")
    #test_labels = pd.read_csv(os.path.join(args.dataset_dir, "label/test_label.csv"))

    training_data_transform = T.Compose([
        #T.ToPILImage("RGB"),
        #T.RandomRotation(5),
        T.RandomHorizontalFlip(0.5),
        # SquarePad(),
        T.Resize((128, 128)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_data_transform = T.Compose([
        #T.ToPILImage("RGB"),
        # SquarePad(),
        T.Resize((128, 128)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
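    # NOTE: recent PyTorch accepts any iterable of indices as a DataLoader
    # sampler, so the NumPy index arrays below stand in for a Sampler object;
    # shuffle must then stay False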
    train_sample = np.random.choice(range(249800), 40000, replace=False)
    train_set = PretrainImageDataset(train_labels,
                                     train_imgs_dir,
                                     transform=training_data_transform)
    val_sample = np.random.choice(range(15770), 4000, replace=False)
    val_set = PretrainImageDataset(val_labels,
                                   val_imgs_dir,
                                   transform=test_data_transform)
    #test_set = ImageDataset(test_labels, test_imgs_dir, transform=test_data_transform)

    #test dataset

    #print("testset: ",len(test_set))

    train_dataloader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  sampler=train_sample)
    val_dataloader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                sampler=val_sample)
    #test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True)
    print("trainset: ", len(train_dataloader))
    print("val: ", len(val_dataloader))

    test_dataloader = None

    # Visualization mode: log sample images and their labels to TensorBoard
    if args.visualize:
        writer = SummaryWriter('runs_pretrained/' + args.model_name)
        # plot the images in the batch, along with predicted and true labels
        for i in range(30):
            fig = plt.figure(figsize=(12, 48))

            image1 = train_set[i]['image_1']
            image2 = train_set[i]['image_2']
            image3 = train_set[i]['image_3']
            images = [image1, image2, image3]
            label_A = train_set[i]['label_A']
            label_B = train_set[i]['label_B']

            for idx in np.arange(3):
                ax = fig.add_subplot(1, 3, idx + 1, xticks=[], yticks=[])
                matplotlib_imshow(images[idx], one_channel=False)
                ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
                    "percentage", label_B[idx], classes[label_A]))
            writer.add_figure('predictions vs. actuals', fig, global_step=i)

    elif args.examine:
        model = Siamese()
        model.load_state_dict(
            torch.load('runs_pretrained/' + args.model_name + '/' +
                       args.model_name + '.pth'))
        model.to(device)
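        # NOTE: test_dataloader is still None here; a real test set has to be
        # wired up before --examine mode can produce predictions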

        images, labels, probs = get_predictions(args, model, test_dataloader,
                                                device)
        pred_labels = torch.argmax(probs, 1)
        cm = confusion_matrix(labels, pred_labels)
        #plot_confusion_matrix(args, labels, pred_labels)
        plot_confusion_matrix(args,
                              cm,
                              l_classes=np.asarray(classes),
                              normalize=True,
                              title='Normalized confusion matrix')
        print("done!")
    else:

        writer = SummaryWriter('runs_pretrained/' + args.model_name)

        if args.continue_train == "NONE":
            model = Siamese()
            model.apply(initialize_parameters)

        else:

            model = Siamese()
            model.load_state_dict(
                torch.load('runs_pretrained/' + args.continue_train + '/' +
                           args.continue_train + '.pth'))
            print("CONTINUE TRAIN MODE----")

        def count_parameters(model):
            return sum(p.numel() for p in model.parameters()
                       if p.requires_grad)

        print(
            f'The model has {count_parameters(model):,} trainable parameters')

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
        Loss_B = Multi_cross_entropy()
        criterion = [nn.CrossEntropyLoss(), Loss_B]
        model.to(device)
        #criterion = criterion.to(device)
        model.train()
        optimizer.zero_grad()

        # Define optimizer
        # opt = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        # Record loss and accuracy history
        args.train_loss = []
        args.val_loss = []
        args.val_acc = []

        # Train the model
        best_valid_loss = float('inf')

        for epoch in range(1, args.epochs + 1):
            start_time = time.monotonic()
            best_valid_loss = train(args, epoch, model, train_dataloader,
                                    val_dataloader, optimizer, criterion,
                                    device, writer, best_valid_loss)
            end_time = time.monotonic()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)
            print(
                f'Epoch: {epoch :02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )

        # Evaluate on test set
        writer.flush()
        """
Example #26
def main(argv):
    # ==============================
    # SET-UP
    # ==============================
    # create summary writer, publisher
    Path(opt.logs).mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(os.path.join(opt.logs,f'lr{opt.lr}_{time.time()}'))
    # load dataset
    data_loader = CarLoader(opt)
    dataset = data_loader.load_train()
    # load network, and publish
    model = Classifier(opt)
    # load from checkpoint?
    if opt.load_checkpoint:
        model.load_state_dict(torch.load(opt.load_checkpoint))
    writer.add_graph(model.model, torch.randn(1,3,224,224))
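    # a dummy (1, 3, 224, 224) batch lets TensorBoard trace the graph once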
    # load loss function
    criterion = torch.nn.CrossEntropyLoss()
    # move to gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    # load optimizer
    optimizer = torch.optim.Adam(model.model.parameters(),
                                 lr=opt.lr,
                                 betas=(opt.beta1, opt.beta2))
    # ==============================
    # TRAINING
    # ==============================
    for epoch in range(opt.max_epochs):
        # epoch set-up
        since = time.time()
        total = 0
        current_loss = 0.0
        accuracy = 0.0
        # iterate through dataset
        for i, (imgs, labels) in enumerate(iter(dataset)):
            # move tensors to the available GPU (or stay on CPU)
            imgs = imgs.to(device)
            labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward pass
            predictions = model(imgs)
            loss = criterion(predictions, labels)
            # update weights through backpropagation
            loss.backward()
            optimizer.step()
            # keep records
            total += labels.size(0)
            current_loss += loss.item()
            _, preds = torch.max(predictions.data, 1)
            accuracy += (preds == labels).sum().item()
            print(f'\rprocessing: {i}/{len(dataset)}, loss: {loss.item(): >10}', end='')
        # save checkpoint, create dir if not already there
        if epoch % opt.save_freq == 0 or epoch == opt.max_epochs-1:
            Path(opt.checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(model.model.state_dict(),
                       os.path.join(opt.checkpoint, f'e{epoch}_l{loss.item():.4f}'))
        # export summary
        current_loss = current_loss/total
        accuracy = accuracy/total
        # publish to TensorBoard
        if epoch % opt.summary_freq == 0 or epoch == opt.max_epochs-1:
            writer.add_scalar('loss', current_loss, epoch)
            writer.add_scalar('accuracy', accuracy, epoch)
            # publish grid with predictions
            fig = _plot_grid(predictions, imgs, labels, data_loader.classes)
            writer.add_figure('predictions', fig, global_step=epoch)
            for name, param in model.model.named_parameters():
                writer.add_histogram(f'weights_{name}', param.data.cpu().numpy(), epoch)
                writer.add_histogram(f'grad_{name}', param.grad.data.cpu().numpy(), epoch)
        # print summary
        print('\repoch: {e:>6}, loss: {loss:.4f}, accuracy: {acc:.4f}, in: {time:.4f}s'\
                .format(e=epoch,
                        loss = current_loss,
                        acc = accuracy,
                        time = time.time()-since))
    writer.close()
Example #27
File: train.py Project: Hiroshiba/hifi-gan
def train(rank, a, h):
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
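    # (last_epoch above restores the LR decay schedule when resuming from a
    # checkpoint; -1 starts the schedule fresh)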

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
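    # DistributedSampler shards the data across ranks, so the DataLoader's own
    # shuffle stays False and set_epoch() below reshuffles each epoch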

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
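            # the factor 45 is the mel-loss weight (lambda_mel) from the HiFi-GAN paper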
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path, 
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                                         else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                                         else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                          h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
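                            # STFT padding can leave the two mel spectrograms a
                            # few frames apart; trim both to the common length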
                            if y_mel.shape[2] != y_g_hat_mel.shape[2]:
                                assert abs(y_mel.shape[2] - y_g_hat_mel.shape[2]) < 5
                                min_length = min(y_mel.shape[2], y_g_hat_mel.shape[2])
                                y_mel = y_mel[:, :, :min_length]
                                y_g_hat_mel = y_g_hat_mel[:, :, :min_length]
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                        val_err = val_err_tot / (j+1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()
        
        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
Example #28
File: lidc.py Project: lobantseff/rls-med
        im1 = ax[0, 1].imshow(clip_scan[:, :, clip_scan.shape[2] // 2],
                              cmap="gray")
        ax[0, 1].axis('off')
        divider = make_axes_locatable(ax[0, 1])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im1, cax=cax, orientation='vertical')

        im15 = ax[0, 2].imshow(mask)
        ax[0, 2].axis('off')
        divider = make_axes_locatable(ax[0, 2])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im15, cax=cax, orientation='vertical')

        im2 = ax[1, 0].imshow(img_01.numpy()[0], cmap="gray")
        ax[1, 0].axis('off')
        divider = make_axes_locatable(ax[1, 0])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im2, cax=cax, orientation='vertical')

        im3 = ax[1, 1].imshow(scan[:, :, scan.shape[2] // 2], cmap="gray")
        ax[1, 1].axis('off')
        divider = make_axes_locatable(ax[1, 1])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im3, cax=cax, orientation='vertical')

        fig.suptitle(patient_id)
        fig.tight_layout()
        writer.add_figure("sample_fig", fig, i)

    print("Done")
Example #29
def evaluate(init_lr=0.1,
             max_steps=320,
             mode='rgb',
             batch_size=20,
             save_model=''):

    logger = SummaryWriter()

    if args.model in ('i3d', 'bpi3d', 'di3d', 'mbi3d'):
        scale_size = 224
    elif args.model in ('r2plus1d', 'w3d', 'bpc3d'):
        scale_size = 112
    elif args.model in ('tsn', 'bptsn'):
        scale_size = 224  # 299
    else:
        raise Exception('Model %s not implemented' % args.model)

    if args.model in ('i3d', 'r2plus1d', 'bpi3d', 'bpc3d', 'di3d', 'mbi3d'):
        if args.mode == 'rgb':
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]
        elif args.mode == 'flow':
            mean = [0.5, 0.5]
            std = [0.03, 0.03]
        elif args.mode == 'rgb+flow':
            mean = [0.5, 0.5, 0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5, 0.03, 0.03]

    elif args.model == 'tsn' or args.model == 'bptsn':
        if args.mode == 'rgb':
            mean = [104. / 255., 117. / 255., 128. / 255.]
            std = [1., 1., 1.]
        else:
            mean = [0.5, 0.5, 0.5]
            std = [1., 1., 1.]

    else:
        raise Exception('Model %s not implemented' % args.model)

    if args.extract_scores:
        train_transforms = Compose([
            MultiScaleRandomCrop([1.0, 0.95, 0.95 * 0.95], scale_size),
            RandomHorizontalFlip(),
            ToTensor(255.0),
            Normalize(mean, std)
        ])

        test_transforms = Compose(
            [CenterCrop(scale_size),
             ToTensor(255.0),
             Normalize(mean, std)])

        clip_len = args.clip_len // args.sample_step
        temporal_transforms = Compose(
            [TemporalRandomCrop(clip_len),
             RepeatPadding(clip_len)])

        #dataset = Dataset(train_split, 'training', root, mode, train_transforms)
        dataset = PEV(data_root,
                      '/home/lizhongguo/dataset/pev_split/train_split_%d.txt' %
                      split_idx,
                      'training',
                      n_samples_for_each_video=args.n_samples,
                      spatial_transform=train_transforms,
                      temporal_transform=temporal_transforms,
                      target_transform=VideoID(),
                      sample_duration=clip_len,
                      sample_freq=args.sample_freq,
                      mode=args.mode,
                      sample_step=args.sample_step,
                      view=args.view)

        if args.distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = None

        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=(sampler is None),
                                                 num_workers=2,
                                                 pin_memory=True,
                                                 sampler=sampler,
                                                 drop_last=False)

        val_temporal_transforms = Compose(
            [TemporalBeginCrop(clip_len),
             RepeatPadding(clip_len)])
        val_dataset = PEV(
            data_root,
            '/home/lizhongguo/dataset/pev_split/val_split_%d.txt' % split_idx,
            'evaluation',
            args.n_samples,
            spatial_transform=test_transforms,
            temporal_transform=val_temporal_transforms,
            target_transform=VideoID(),
            sample_duration=clip_len,
            sample_freq=args.sample_freq,
            mode=args.mode,
            sample_step=args.sample_step,
            view=args.view)

        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset)
        else:
            val_sampler = None

        val_dataloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=(val_sampler is None),
            num_workers=2,
            pin_memory=True,
            sampler=val_sampler,
            drop_last=False)

        # setup the model
        if args.model == 'mbi3d' and args.mode == 'rgb+flow':
            model_f = MBI3D(7, args.fuse, ['fflow', 'sflow'])
            model_s = MBI3D(7, args.fuse, ['frgb', 'srgb'])
            checkpoint = torch.load(
                        'pev_split_%d_%s_%s_best%sfs.pt' % (args.split_idx, args.model, 'flow', \
                            '_' if args.fuse=='cbp' else '_'+args.fuse+'_'), map_location=lambda storage, loc: storage)
            model_f.load_state_dict(checkpoint['state_dict'])
            checkpoint = torch.load(
                        'pev_split_%d_%s_%s_best%sfs.pt' % (args.split_idx, args.model, 'rgb', \
                            '_' if args.fuse=='cbp' else '_'+args.fuse+'_'), map_location=lambda storage, loc: storage)
            model_s.load_state_dict(checkpoint['state_dict'])
            model_f.cuda()
            model_s.cuda()
            model_f.train(False)
            model_s.train(False)
        elif args.mode == 'rgb' or args.mode == 'flow':
            model_f = InceptionI3d(num_classes=7,
                                   in_channels=2 if args.mode == 'flow' else 3)
            model_s = InceptionI3d(num_classes=7,
                                   in_channels=2 if args.mode == 'flow' else 3)
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_f.pt' %
                                    (args.split_idx, args.model, args.mode),
                                    map_location=lambda storage, loc: storage)
            model_f.load_state_dict(checkpoint['state_dict'])
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_s.pt' %
                                    (args.split_idx, args.model, args.mode),
                                    map_location=lambda storage, loc: storage)
            model_s.load_state_dict(checkpoint['state_dict'])
            model_f.cuda()
            model_s.cuda()
            model_f.train(False)
            model_s.train(False)

        else:
            model_flow_f = InceptionI3d(num_classes=7, in_channels=2)
            model_flow_s = InceptionI3d(num_classes=7, in_channels=2)
            model_rgb_f = InceptionI3d(num_classes=7, in_channels=3)
            model_rgb_s = InceptionI3d(num_classes=7, in_channels=3)
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_f.pt' %
                                    (args.split_idx, args.model, 'flow'),
                                    map_location=lambda storage, loc: storage)
            model_flow_f.load_state_dict(checkpoint['state_dict'])
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_s.pt' %
                                    (args.split_idx, args.model, 'flow'),
                                    map_location=lambda storage, loc: storage)
            model_flow_s.load_state_dict(checkpoint['state_dict'])
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_f.pt' %
                                    (args.split_idx, args.model, 'rgb'),
                                    map_location=lambda storage, loc: storage)
            model_rgb_f.load_state_dict(checkpoint['state_dict'])
            checkpoint = torch.load('pev_split_%d_%s_%s_best_cat_s.pt' %
                                    (args.split_idx, args.model, 'rgb'),
                                    map_location=lambda storage, loc: storage)
            model_rgb_s.load_state_dict(checkpoint['state_dict'])

            model_flow_f.cuda()
            model_flow_f.train(False)
            model_flow_s.cuda()
            model_flow_s.train(False)
            model_rgb_f.cuda()
            model_rgb_f.train(False)
            model_rgb_s.cuda()
            model_rgb_s.train(False)

        pred_result = []
        with torch.no_grad():
            for _ in range(args.epochs):
                for data in tqdm(dataloader):
                    if args.model == 'mbi3d' and args.mode == 'rgb+flow':
                        input_f, input_s, labels, _ = data
                        input_f, input_s = input_f.cuda(), input_s.cuda()
                        labels = labels.cuda(non_blocking=True)
                        output = torch.cat([
                            F.softmax(model_f(input_f[:, 3:, :, :, :],
                                              input_s[:, 3:, :, :, :]),
                                      dim=1),
                            F.softmax(model_s(input_f[:, :3, :, :, :],
                                              input_s[:, :3, :, :, :]),
                                      dim=1)
                        ],
                                           dim=1)
                        for o, i in zip(output, labels):
                            pred_result.append(
                                [o.cpu().numpy(),
                                 i.cpu().numpy()])
                    elif args.mode == 'rgb+flow':
                        input_f, input_s, labels, _ = data
                        input_f, input_s = input_f.cuda(), input_s.cuda()
                        labels = labels.cuda(non_blocking=True)
                        output = torch.cat([F.softmax(model_flow_f(input_f[:,3:,:,:,:]), dim=1), \
                            F.softmax(model_flow_s(input_s[:,3:,:,:,:]), dim=1), \
                            F.softmax(model_rgb_f(input_f[:,:3,:,:,:]), dim=1), \
                            F.softmax(model_rgb_s(input_s[:,:3,:,:,:]), dim=1)], dim=1)
                        for o, i in zip(output, labels):
                            pred_result.append(
                                [o.cpu().numpy(),
                                 i.cpu().numpy()])
                    else:
                        input_f, input_s, labels, _ = data
                        input_f, input_s = input_f.cuda(), input_s.cuda()
                        labels = labels.cuda(non_blocking=True)
                        output = torch.cat([F.softmax(model_f(input_f), dim=1), \
                            F.softmax(model_s(input_s), dim=1)], dim=1)

                        for o, i in zip(output, labels):
                            pred_result.append(
                                [o.cpu().numpy(),
                                 i.cpu().numpy()])

        torch.save(
            pred_result, '%s_split_%d_%s_%s_%s_%strain_scores.pt' %
            ('pev', args.split_idx, args.model, args.mode, args.view,
             '' if args.fuse == 'cbp' else args.fuse + '_'))

        val_pred_result = []
        with torch.no_grad():
            for data in tqdm(val_dataloader):
                if args.model == 'mbi3d' and args.mode == 'rgb+flow':
                    input_f, input_s, labels, _ = data
                    input_f, input_s = input_f.cuda(), input_s.cuda()
                    labels = labels.cuda(non_blocking=True)
                    output = torch.cat([
                        F.softmax(model_f(input_f[:, 3:, :, :, :],
                                          input_s[:, 3:, :, :, :]),
                                  dim=1),
                        F.softmax(model_s(input_f[:, :3, :, :, :],
                                          input_s[:, :3, :, :, :]),
                                  dim=1)
                    ],
                                       dim=1)
                    for o, i in zip(output, labels):
                        val_pred_result.append(
                            [o.cpu().numpy(), i.cpu().numpy()])
                elif args.mode == 'rgb+flow':
                    input_f, input_s, labels, _ = data
                    input_f, input_s = input_f.cuda(), input_s.cuda()
                    labels = labels.cuda(non_blocking=True)
                    output = torch.cat([F.softmax(model_flow_f(input_f[:,3:,:,:,:]), dim=1), \
                        F.softmax(model_flow_s(input_s[:,3:,:,:,:]), dim=1), \
                        F.softmax(model_rgb_f(input_f[:,:3,:,:,:]), dim=1), \
                        F.softmax(model_rgb_s(input_s[:,:3,:,:,:]), dim=1)], dim=1)
                    for o, i in zip(output, labels):
                        val_pred_result.append(
                            [o.cpu().numpy(), i.cpu().numpy()])
                else:
                    input_f, input_s, labels, _ = data
                    input_f, input_s = input_f.cuda(), input_s.cuda()
                    labels = labels.cuda(non_blocking=True)
                    output = torch.cat([F.softmax(model_f(input_f), dim=1), \
                        F.softmax(model_s(input_s), dim=1)], dim=1)

                    for o, i in zip(output, labels):
                        val_pred_result.append(
                            [o.cpu().numpy(), i.cpu().numpy()])

        torch.save(
            val_pred_result, '%s_split_%d_%s_%s_%s_%sval_scores.pt' %
            ('pev', args.split_idx, args.model, args.mode, args.view,
             '' if args.fuse == 'cbp' else args.fuse + '_'))

        return
    test_transforms = Compose(
        [CenterCrop(scale_size),
         ToTensor(255.0),
         Normalize(mean, std)])
    clip_len = args.clip_len // args.sample_step

    temporal_transforms = Compose(
        [TemporalBeginCrop(clip_len),
         RepeatPadding(clip_len)])
    target_transforms = VideoID()
    val_dataset = PEV(data_root,
                      '/home/lizhongguo/dataset/pev_split/val_split_%d.txt' %
                      split_idx,
                      'evaluation',
                      args.n_samples,
                      spatial_transform=test_transforms,
                      temporal_transform=temporal_transforms,
                      target_transform=target_transforms,
                      sample_duration=clip_len,
                      sample_freq=args.sample_freq,
                      mode=args.mode,
                      sample_step=args.sample_step,
                      view=args.view)
    #val_dataset.random_select = True

    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=True,
                                                 drop_last=False)

    # setup the model
    model, _, _ = model_builder()
    model.train(False)

    if args.visualize:
        model.add_logger(logger, 'Conv3d_1a_7x7')

    top1 = AverageMeter()
    top2 = AverageMeter()
    label_names = ['Pit', 'Att', 'Pas', 'Rec', 'Pos', 'Neg', 'Ges']
    confusion_matrix = MatrixMeter(label_names)

    pred_result = dict()
    with torch.no_grad():
        for data in tqdm(val_dataloader):
            if args.model == 'di3d':
                input_f, input_s, labels, _ = data
                input_f, input_s = input_f.cuda(), input_s.cuda()
                labels = labels.cuda(non_blocking=True)
                output = model(input_f, input_s)

            elif args.model == 'mbi3d':
                if args.view == 'fs':
                    input_f, input_s, labels, _ = data
                    input_f, input_s = input_f.cuda(), input_s.cuda()
                    labels = labels.cuda(non_blocking=True)

                    if args.mode == 'rgb+flow':
                        output = model(input_f[:, :3, :, :, :],
                                       input_f[:, 3:, :, :, :],
                                       input_s[:, :3, :, :, :],
                                       input_s[:, 3:, :, :, :])
                    else:
                        output = model(input_f, input_s)
                elif args.view == 'f' or args.view == 's':
                    assert args.mode == 'rgb+flow'
                    inputs, labels, _ = data
                    inputs = inputs.cuda()
                    labels = labels.cuda(non_blocking=True)
                    output = model(inputs[:, :3, :, :, :], inputs[:,
                                                                  3:, :, :, :])

            else:
                inputs, labels, _ = data
                inputs = inputs.cuda()
                labels = labels.cuda(non_blocking=True)
                output = model(inputs)

            output = F.softmax(output, dim=1)

            for o, i in zip(output, labels):
                if i not in pred_result:
                    pred_result[i] = []
                pred_result[i].append(o)

        torch.save(
            pred_result, '%s_split_%d_%s_%s_%s_%s_result.pt' %
            ('pev', args.split_idx, args.model, args.mode, args.fuse,
             args.view))
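        # average the per-clip softmax scores for each video, then take top-2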

        for i in pred_result:
            avg_pred = torch.stack(tuple(o for o in pred_result[i]),
                                   dim=0).mean(dim=0)
            target = val_dataset.id2label[i.item()]
            _, prediction = avg_pred.topk(2)
            prediction = prediction.tolist()
            if target == prediction[0]:
                top1.update(1., n=1)
            else:
                top1.update(0., n=1)

            if target in prediction:
                top2.update(1., n=1)
            else:
                top2.update(0., n=1)

            confusion_matrix.update(prediction[0], target)

    logger.add_scalar('val/top1', 100 * top1.avg, 0)
    logger.add_scalar('val/top2', 100 * top2.avg, 0)
    logger.add_figure('val/confusion',
                      draw_confusion_matrix(confusion_matrix._data,
                                            label_names),
                      0,
                      close=False)

    print("Top1:%.2f Top2:%.2f" % (confusion_matrix.acc * 100, top2.avg * 100))
    print(confusion_matrix)

    logger.close()
Example #30
def main(config: ConfigParser):
    # Set up logging
    tensorboard = SummaryWriter(config.log_dir_test)
    logger = config.get_logger('test')

    # Setup data_loader instances
    if config['data_loader']['iterator']:
        data_loader = config.init_obj_from_file('data_loader', test=True)
        test_loader = data_loader.split_test()

    else:
        raise NotImplementedError

    # Build model architecture
    model = config.init_obj_from_file('arch')
    logger.info(model)

    # Check the dimensions
    assert len(data_loader.SRC.vocab) == config['arch']['args']['input_dim'], "Input dimensions need to match"
    assert len(data_loader.TRG.vocab) == config['arch']['args']['output_dim'], "Output dimensions need to match"

    # Load the model parameters
    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # Get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss']['function'])

    # Set the padding index in the criterion such that we ignore pad tokens
    if config['loss']['padding_idx']:
        loss_fn = loss_fn(data_loader.TRG.vocab.stoi['<pad>'])

    if 'packed' in config['arch'] and config['arch']['packed']:
        model.set_tokens(data_loader.SRC.vocab.stoi['<pad>'],
                         data_loader.TRG.vocab.stoi['<sos>'],
                         data_loader.TRG.vocab.stoi['<eos>'])

    # Prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    print("Computing Loss...")
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader)):
            output, target = model.process_batch(batch, train=False)

            # computing loss, metrics on test set; weight each batch by its size
            loss = loss_fn(output, target)
            batch_size = output.shape[0]
            total_loss += loss.item() * batch_size
            n_samples += batch_size

    print(f"Samples: {n_samples}")
    avg_loss = total_loss / n_samples
    log = {
        'loss':       avg_loss,
        'perplexity': np.exp(avg_loss)
    }
    logger.info(log)

    # Set up the translate function
    translate = get_translation_fn(model, device, data_loader)

    # BLEU score init
    hypotheses = list()
    references = list()
    # Smoothing adds epsilon to the BLEU n-gram counts
    smoother = SmoothingFunction().method1

    files = config['inference']

    # Pretty-print inference information
    logger.info("Starting inference")
    for key, value in files.items():
        t, file = str(key).split('_')
        logger.info(f'\t {t.capitalize():5s} {file.capitalize():5s} \t: {value}')


    # Keep track of a src, trg and predictions file
    with open(files['src_file'], encoding='utf-8') as src_file, \
            open(files['trg_file'], encoding='utf-8') as trg_file, \
            open(files['pred_file'], "a+", encoding='utf-8') as pred_file:

        # For each file from the src and trg
        for idx, (src_sent, trg_sent) in tqdm(enumerate(zip(src_file, trg_file))):
            # Strip white space and convert to lower
            src_sent = src_sent.strip().lower()
            trg_sent = trg_sent.strip().lower()

            # Tokenize
            src_tokens = data_loader.src_tokenize(src_sent)
            trg_tokens = data_loader.trg_tokenize(trg_sent)

            # Translate with the trained model
            pred_tokens, attention = translate(src_tokens)
            pred_sent = " ".join(pred_tokens)

            # Save results to tensorboard
            tensorboard.add_text(f"{idx}/src", src_sent)
            tensorboard.add_text(f"{idx}/trg", trg_sent)
            tensorboard.add_text(f"{idx}/pred", pred_sent)
            tensorboard.add_scalar(f"{idx}/sent_bleu",
                                   sentence_bleu(trg_tokens, pred_tokens, smoothing_function=smoother))

            references.append(trg_tokens)
            hypotheses.append(pred_tokens)

            # Visualize the attention
            att_fig = display_attention(src_tokens, pred_tokens, attention)
            tensorboard.add_figure(f"{idx}/attention", att_fig)
            plt.close(att_fig)

            # Print to predictions file
            print(pred_sent, file=pred_file)

    # Compute the corpus BLEU score (corpus_bleu expects the references first,
    # as one list of reference token lists per hypothesis)
    tensorboard.add_scalar("test/corpus_bleu",
                           corpus_bleu([[ref] for ref in references], hypotheses,
                                       smoothing_function=smoother))