Example #1
def run_train(dataset_path,
              exp_name,
              max_shapes,
              epochs,
              hidden_dim,
              eval_per,
              variational,
              loss_config,
              enc_lr,
              dec_lr,
              enc_step,
              dec_step,
              enc_decay,
              dec_decay,
              batch_size,
              holdout_perc,
              rd_seed,
              print_per,
              num_gen,
              num_eval,
              keep_missing,
              category,
              load_epoch=None):

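    # Seed Python, NumPy, and PyTorch RNGs for reproducible runs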
    random.seed(rd_seed)
    np.random.seed(rd_seed)
    torch.manual_seed(rd_seed)

    raw_indices, progs = load_progs(dataset_path, max_shapes)

    inds_and_progs = list(zip(raw_indices, progs))
    random.shuffle(inds_and_progs)

    inds_and_progs = inds_and_progs[:max_shapes]

    decoder = FDGRU(hidden_dim)
    decoder.to(device)

    encoder = ENCGRU(hidden_dim)
    encoder.to(device)

    print('Converting progs to tensors')

    samples = []
    for ind, prog in tqdm(inds_and_progs):
        nprog = progToData(prog)
        samples.append((nprog, ind))

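    # Separate Adam optimizers and step-decay LR schedulers for decoder and encoder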
    dec_opt = torch.optim.Adam(decoder.parameters(), lr=dec_lr, eps=ADAM_EPS)

    enc_opt = torch.optim.Adam(encoder.parameters(), lr=enc_lr, eps=ADAM_EPS)

    dec_sch = torch.optim.lr_scheduler.StepLR(dec_opt,
                                              step_size=dec_step,
                                              gamma=dec_decay)

    enc_sch = torch.optim.lr_scheduler.StepLR(enc_opt,
                                              step_size=enc_step,
                                              gamma=enc_decay)

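    # Assign each sample to train or val according to the precomputed split files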
    train_ind_file = f'data_splits/{category}/train.txt'
    val_ind_file = f'data_splits/{category}/val.txt'

    train_samples = []
    val_samples = []

    train_inds = getInds(train_ind_file)
    val_inds = getInds(val_ind_file)

    kept = 0.
    misses = 0.

    for (prog, ind) in samples:
        if ind in train_inds:
            train_samples.append((prog, ind))
        elif ind in val_inds:
            val_samples.append((prog, ind))
        else:
            if keep_missing:
                kept += 1
                if random.random() < holdout_perc:
                    val_samples.append((prog, ind))
                else:
                    train_samples.append((prog, ind))
            else:
                misses += 1

    print(f"Samples missed: {misses}, kept without a split: {kept}")
    train_num = len(train_samples)
    val_num = len(val_samples)

    train_dataset = DataLoader(train_samples,
                               batch_size,
                               shuffle=True,
                               collate_fn=_col)
    eval_train_dataset = DataLoader(train_samples[:num_eval],
                                    batch_size=1,
                                    shuffle=False,
                                    collate_fn=_col)
    val_dataset = DataLoader(val_samples,
                             batch_size,
                             shuffle=False,
                             collate_fn=_col)
    eval_val_dataset = DataLoader(val_samples[:num_eval],
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=_col)

    utils.log_print(f"Training size: {train_num}",
                    f"{outpath}/{exp_name}/log.txt")
    utils.log_print(f"Validation size: {val_num}",
                    f"{outpath}/{exp_name}/log.txt")

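    # Log ground-truth metrics on the validation subset as a reference point
    # for the unconditional generation metrics reported during training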
    with torch.no_grad():
        gt_gen_results, _ = metrics.gen_metrics(
            [s[0] for s in val_samples[:num_eval]], '', '', '', VERBOSE, False)

    utils.log_print(
        f""" 
  GT Val Number of parts = {gt_gen_results['num_parts']}
  GT Val Variance = {gt_gen_results['variance']}
  GT Val Rootedness = {gt_gen_results['rootedness']}
  GT Val Stability = {gt_gen_results['stability']}
""", f"{outpath}/{exp_name}/log.txt")

    aepochs = []

    train_res_plots = {}
    val_res_plots = {}
    gen_res_plots = {}
    eval_res_plots = {'train': {}, 'val': {}}

    print('training ...')

    if load_epoch is None:
        start = 0
    else:
        start = load_epoch + 1

    for e in range(start, epochs):
        do_print = (e + 1) % print_per == 0
        t = time.time()
        if do_print:
            utils.log_print(f"\nEpoch {e}:", f"{outpath}/{exp_name}/log.txt")

        train_ep_result = model_train_results(train_dataset, encoder, decoder,
                                              dec_opt, enc_opt, variational,
                                              loss_config, 'train', do_print,
                                              exp_name)

        dec_sch.step()
        enc_sch.step()

        if do_print:
            utils.log_print(f"  Train Epoch Time = {time.time() - t}",
                            f"{outpath}/{exp_name}/log.txt")

        if (e + 1) % eval_per == 0:

            with torch.no_grad():
                t = time.time()
                utils.log_print("Doing Evaluation",
                                f"{outpath}/{exp_name}/log.txt")

                val_ep_result = model_train_results(val_dataset, encoder,
                                                    decoder, None, None, False,
                                                    loss_config, 'val', True,
                                                    exp_name)

                eval_results, gen_results = model_eval(eval_train_dataset,
                                                       eval_val_dataset,
                                                       encoder, decoder,
                                                       exp_name, e, num_gen)

                for name, named_results in eval_results:
                    if named_results['nc'] > 0:
                        named_results['cub_prm'] /= named_results['nc']

                    if named_results['na'] > 0:
                        named_results['xyz_prm'] /= named_results['na']
                        named_results['cubc'] /= named_results['na']

                    if named_results['count'] > 0:
                        named_results['bb'] /= named_results['count']

                    if named_results['nl'] > 0:
                        named_results['cmdc'] /= named_results['nl']

                    if named_results['ns'] > 0:
                        named_results['sym_cubc'] /= named_results['ns']
                        named_results['axisc'] /= named_results['ns']

                    if named_results['np'] > 0:
                        named_results['corr_line_num'] /= named_results['np']
                        named_results['bad_leaf'] /= named_results['np']

                    if named_results['nsq'] > 0:
                        named_results['uv_prm'] /= named_results['nsq']
                        named_results['sq_cubc'] /= named_results['nsq']
                        named_results['facec'] /= named_results['nsq']

                    if named_results['nap'] > 0:
                        named_results['palignc'] /= named_results['nap']

                    if named_results['nan'] > 0:
                        named_results['nalignc'] /= named_results['nan']

                    named_results.pop('nc')
                    named_results.pop('nan')
                    named_results.pop('nap')
                    named_results.pop('na')
                    named_results.pop('ns')
                    named_results.pop('nsq')
                    named_results.pop('nl')
                    named_results.pop('count')
                    named_results.pop('np')
                    named_results.pop('cub')
                    named_results.pop('sym_cub')
                    named_results.pop('axis')
                    named_results.pop('cmd')
                    named_results.pop('miss_hier_prog')

                    utils.log_print(
                        f"""

  Evaluation on {name} set:
                  
  Eval {name} F-score = {named_results['fscores']}
  Eval {name} IoU = {named_results['iou_shape']}
  Eval {name} PD = {named_results['param_dist_parts']}
  Eval {name} Prog Creation Perc: {named_results['prog_creation_perc']}
  Eval {name} Cub Prm Loss = {named_results['cub_prm']} 
  Eval {name} XYZ Prm Loss = {named_results['xyz_prm']}
  Eval {name} UV Prm Loss = {named_results['uv_prm']}
  Eval {name} Sym Prm Loss = {named_results['sym_prm']}
  Eval {name} BBox Loss = {named_results['bb']}
  Eval {name} Cmd Corr % {named_results['cmdc']}
  Eval {name} Cub Corr % {named_results['cubc']}
  Eval {name} Squeeze Cub Corr % {named_results['sq_cubc']}
  Eval {name} Face Corr % {named_results['facec']}
  Eval {name} Pos Align Corr % {named_results['palignc']}
  Eval {name} Neg Align Corr % {named_results['nalignc']}
  Eval {name} Sym Cub Corr % {named_results['sym_cubc']}
  Eval {name} Sym Axis Corr % {named_results['axisc']}
  Eval {name} Corr Line # % {named_results['corr_line_num']}
  Eval {name} Bad Leaf % {named_results['bad_leaf']}

""", f"{outpath}/{exp_name}/log.txt")

                utils.log_print(
                    f"""
  Gen Prog creation % = {gen_results['prog_creation_perc']}
  Gen Number of parts = {gen_results['num_parts']}
  Gen Variance = {gen_results['variance']}
  Gen Rootedness = {gen_results['rootedness']}
  Gen Stability = {gen_results['stability']}
""", f"{outpath}/{exp_name}/log.txt")

                utils.log_print(f"Eval Time = {time.time() - t}",
                                f"{outpath}/{exp_name}/log.txt")

                # Plotting logic

                for key in train_ep_result:
                    res = train_ep_result[key]
                    if torch.is_tensor(res):
                        res = res.detach().item()
                    if key not in train_res_plots:
                        train_res_plots[key] = [res]
                    else:
                        train_res_plots[key].append(res)

                for key in val_ep_result:
                    res = val_ep_result[key]
                    if torch.is_tensor(res):
                        res = res.detach().item()
                    if key not in val_res_plots:
                        val_res_plots[key] = [res]
                    else:
                        val_res_plots[key].append(res)

                for key in gen_results:
                    res = gen_results[key]
                    if torch.is_tensor(res):
                        res = res.detach().item()
                    if key not in gen_res_plots:
                        gen_res_plots[key] = [res]
                    else:
                        gen_res_plots[key].append(res)

                for name, named_results in eval_results:
                    for key in named_results:
                        res = named_results[key]
                        if torch.is_tensor(res):
                            res = res.detach().item()
                        if key not in eval_res_plots[name]:
                            eval_res_plots[name][key] = [res]
                        else:
                            eval_res_plots[name][key].append(res)

                aepochs.append(e)

                for key in train_res_plots:
                    plt.clf()
                    plt.plot(aepochs, train_res_plots[key], label='train')
                    if key in val_res_plots:
                        plt.plot(aepochs, val_res_plots[key], label='val')
                    plt.legend()
                    if key == "recon":
                        plt.yscale('log')
                    plt.grid()
                    plt.savefig(f"{outpath}/{exp_name}/plots/train/{key}.png")

                for key in gen_res_plots:
                    plt.clf()
                    plt.plot(aepochs, gen_res_plots[key])
                    if key == "variance":
                        plt.yscale('log')
                    plt.grid()
                    plt.savefig(f"{outpath}/{exp_name}/plots/gen/{key}.png")

                for key in eval_res_plots['train']:
                    plt.clf()
                    t_p, = plt.plot(aepochs,
                                    eval_res_plots['train'][key],
                                    label='train')

                    if key in eval_res_plots['val']:
                        v_p, = plt.plot(aepochs,
                                        eval_res_plots['val'][key],
                                        label='val')
                        plt.legend(handles=[t_p, v_p])
                    plt.grid()
                    plt.savefig(f"{outpath}/{exp_name}/plots/eval/{key}.png")

            try:
                if SAVE_MODELS:
                    utils.log_print("Saving Models",
                                    f"{outpath}/{exp_name}/log.txt")
                    # TODO: save state_dict() instead of the whole module, so
                    # only the named model parameters get serialized
                    torch.save(decoder,
                               f"{outpath}/{exp_name}/models/decoder_{e}.pt")
                    torch.save(encoder,
                               f"{outpath}/{exp_name}/models/encoder_{e}.pt")
            except Exception as ex:
                utils.log_print(f"Couldn't save models at epoch {e}: {ex}",
                                f"{outpath}/{exp_name}/log.txt")
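
A hypothetical invocation of run_train, for orientation only; every argument value below is an illustrative assumption, not a setting taken from the source.

# Sketch only: all values are assumptions.
run_train(dataset_path='data/chair',   # assumed location of the program dataset
          exp_name='chair_run',
          max_shapes=5000,
          epochs=1000,
          hidden_dim=256,
          eval_per=50,
          variational=True,
          loss_config={},              # structure depends on model_train_results
          enc_lr=1e-4,
          dec_lr=1e-4,
          enc_step=500,
          dec_step=500,
          enc_decay=0.5,
          dec_decay=0.5,
          batch_size=32,
          holdout_perc=0.1,
          rd_seed=42,
          print_per=10,
          num_gen=100,
          num_eval=100,
          keep_missing=False,
          category='chair')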
Example #2
    def train(self):
        with tf.Session() as sess:
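            # Map trainable variables onto the pre-trained BERT checkpoint;
            # tf.train.init_from_checkpoint overrides their initializers, so the
            # global initializer run below restores the pre-trained weights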
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, self.__bert_checkpoint_path)
            print("init bert model params")
            tf.train.init_from_checkpoint(self.__bert_checkpoint_path,
                                          assignment_map)
            print("init bert model params done")
            sess.run(tf.variables_initializer(tf.global_variables()))

            current_step = 0
            start = time.time()
            for epoch in range(self.config["epochs"]):
                print("----- Epoch {}/{} -----".format(epoch + 1,
                                                       self.config["epochs"]))

                for batch in self.data_obj.next_batch(self.t_in_ids,
                                                      self.t_in_masks,
                                                      self.t_seg_ids,
                                                      self.t_lab_ids,
                                                      self.t_seq_len):

                    loss, true_y, predictions = self.model.train(
                        sess, batch, self.config["keep_prob"])

                    f1, precision, recall = gen_metrics(
                        pred_y=predictions,
                        true_y=true_y,
                        label_to_index=self.lab_to_idx)
                    print(
                        "train: step: {}, loss: {}, recall: {}, precision: {}, f1: {}"
                        .format(current_step, loss, recall, precision, f1))

                    current_step += 1
                    if self.data_obj and current_step % self.config[
                            "checkpoint_every"] == 0:

                        eval_losses = []
                        eval_recalls = []
                        eval_precisions = []
                        eval_f1s = []
                        for eval_batch in self.data_obj.next_batch(
                                self.e_in_ids, self.e_in_masks, self.e_seg_ids,
                                self.e_lab_ids, self.e_seq_len):
                            eval_loss, eval_true_y, eval_predictions = self.model.eval(
                                sess, eval_batch)

                            eval_losses.append(eval_loss)

                            f1, precision, recall = gen_metrics(
                                pred_y=eval_predictions,
                                true_y=eval_true_y,
                                label_to_index=self.lab_to_idx)
                            eval_recalls.append(recall)
                            eval_precisions.append(precision)
                            eval_f1s.append(f1)
                        print("\n")
                        print(
                            "eval:  loss: {}, recall: {}, precision: {}, f1: {}"
                            .format(mean(eval_losses), mean(eval_recalls),
                                    mean(eval_precisions), mean(eval_f1s)))
                        print("\n")

                        if self.config["ckpt_model_path"]:
                            save_path = self.config["ckpt_model_path"]
                            if not os.path.exists(save_path):
                                os.makedirs(save_path)
                            model_save_path = os.path.join(
                                save_path, self.config["model_name"])
                            self.model.saver.save(sess,
                                                  model_save_path,
                                                  global_step=current_step)

            end = time.time()
            print("total train time: ", end - start)
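
The trainer above pulls its hyperparameters from self.config and calls a mean() helper; a minimal sketch of both follows, with the keys inferred from the lookups in train() and all values being illustrative assumptions.

from statistics import mean  # one plausible source for the mean() used above

# Hypothetical config; keys mirror the accesses in train(), values are assumptions.
config = {
    "epochs": 10,                # passes over the training data
    "keep_prob": 0.9,            # dropout keep probability passed to model.train
    "checkpoint_every": 100,     # evaluate and checkpoint every N steps
    "ckpt_model_path": "ckpt/bert",
    "model_name": "bert_model",
}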
Example #3
def model_eval(eval_train_dataset, eval_val_dataset, encoder, decoder,
               exp_name, epoch, num_gen):
    decoder.eval()
    encoder.eval()

    eval_results = []

    for name, dataset in [('train', eval_train_dataset),
                          ('val', eval_val_dataset)]:

        if len(dataset) == 0:
            continue

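        # Per-shape running totals; the counter keys accumulated from
        # shape_result ('na', 'nc', ...) are used by the caller to normalize
        # the metrics (see Example #1)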
        named_results = {'count': 0., 'miss_hier_prog': 0.}

        recon_sets = []

        for batch in dataset:
            for shape in batch:

                named_results['count'] += 1.

                # Always get maximum likelihood estimation (i.e. mean) of shape encoding at eval time
                encoding, _ = get_encoding(shape[0], encoder, mle=True)

                prog, shape_result = run_eval_decoder(encoding, decoder, False,
                                                      shape[0])

                for key in shape_result:
                    if key not in named_results:
                        named_results[key] = shape_result[key]
                    else:
                        named_results[key] += shape_result[key]

                if prog is None:
                    named_results['miss_hier_prog'] += 1.
                    continue

                recon_sets.append((prog, shape[0], shape[1]))

        # For reconstruction, get metric performance
        recon_results, recon_misses = metrics.recon_metrics(
            recon_sets, outpath, exp_name, name, epoch, VERBOSE)

        for key in recon_results:
            named_results[key] = recon_results[key]

        named_results['miss_hier_prog'] += recon_misses

        named_results['prog_creation_perc'] = (
            named_results['count'] -
            named_results['miss_hier_prog']) / named_results['count']

        eval_results.append((name, named_results))

    gen_progs = []

    gen_prog_fails = 0.

    # Also generate a set of unconditional ShapeAssembly Programs
    for i in range(0, num_gen):
        try:
            h0 = torch.randn(1, 1, args.hidden_dim).to(device)
            prog, _ = run_eval_decoder(h0, decoder, True)
            gen_progs.append(prog)

        except Exception as e:
            gen_prog_fails += 1.

            if VERBOSE:
                print(f"Failed generating new program with {e}")

    # Get metrics for unconditional generations
    gen_results, gen_misses = metrics.gen_metrics(gen_progs, outpath, exp_name,
                                                  epoch, VERBOSE)

    if num_gen > 0:
        gen_results['prog_creation_perc'] = (num_gen - gen_misses -
                                             gen_prog_fails) / num_gen

    else:
        gen_results['prog_creation_perc'] = 0.

    return eval_results, gen_results
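
A minimal usage sketch, assuming the DataLoaders, encoder, and decoder are constructed as in Example #1; exp_name, epoch, and num_gen are illustrative assumptions.

# Hypothetical call; inputs built as in Example #1.
with torch.no_grad():
    eval_results, gen_results = model_eval(eval_train_dataset,
                                           eval_val_dataset,
                                           encoder, decoder,
                                           exp_name='chair_run',
                                           epoch=0,
                                           num_gen=100)

for name, named_results in eval_results:
    print(name, named_results['prog_creation_perc'])
print('gen', gen_results['prog_creation_perc'])

# Note: model_eval switches both networks to eval mode; call .train() on them
# before resuming optimization.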