Esempio n. 1
0
    def train_epoch(epo):
        """Run one training epoch and summarise its metrics.

        Performs ``train_batch`` iterations. Each iteration draws an encoder,
        graph and decoder batch, calls ``train_single_batch`` and records the
        five statistics it returns. ``scheduler.step()`` is called once after
        the loop.

        Args:
            epo: epoch index; only used in the returned message.

        Returns:
            (message_train, kl_coef): a formatted string with the epoch means
            of loss / MSE / likelihood / KL / first-point std, and the KL
            weight used for the last batch (0.0 if no batch was run).
        """
        model.train()
        loss_list = []
        mse_list = []
        likelihood_list = []
        kl_first_p_list = []
        std_first_p_list = []

        # Number of in-epoch iterations during which the KL term is disabled.
        # NOTE(review): `itr` restarts at 0 every epoch, so this KL warm-up
        # schedule restarts every epoch as well -- confirm that is intended.
        wait_until_kl_inc = 10
        # Initialised up front so the return below is valid even when
        # train_batch == 0 (previously raised UnboundLocalError).
        kl_coef = 0.

        torch.cuda.empty_cache()

        for itr in tqdm(range(train_batch)):
            if itr < wait_until_kl_inc:
                kl_coef = 0.
            else:
                kl_coef = (1 - 0.99**(itr - wait_until_kl_inc))

            batch_dict_encoder = utils.get_next_batch_new(
                train_encoder, device)
            batch_dict_graph = utils.get_next_batch_new(train_graph, device)
            batch_dict_decoder = utils.get_next_batch(train_decoder, device)

            loss, mse, likelihood, kl_first_p, std_first_p = train_single_batch(
                model, batch_dict_encoder, batch_dict_decoder,
                batch_dict_graph, kl_coef)

            # Record per-batch results.
            loss_list.append(loss)
            mse_list.append(mse)
            likelihood_list.append(likelihood)
            kl_first_p_list.append(kl_first_p)
            std_first_p_list.append(std_first_p)

            # Drop batch references before emptying the CUDA cache so the
            # allocator can actually release their memory.
            del batch_dict_encoder, batch_dict_graph, batch_dict_decoder
            torch.cuda.empty_cache()

        scheduler.step()

        message_train = 'Epoch {:04d} [Train seq (cond on sampled tp)] | Loss {:.6f} | MSE {:.6F} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            epo, np.mean(loss_list), np.mean(mse_list),
            np.mean(likelihood_list), np.mean(kl_first_p_list),
            np.mean(std_first_p_list))

        return message_train, kl_coef
Esempio n. 2
0
    num_batches = data_obj["n_train_batches"]

    for itr in range(1, num_batches * (args.niters + 1)):
        optimizer.zero_grad()
        utils.update_learning_rate(optimizer,
                                   decay_rate=0.999,
                                   lowest=args.lr / 10)

        wait_until_kl_inc = 10
        if itr // num_batches < wait_until_kl_inc:
            kl_coef = 0.
        else:
            kl_coef = (1 - 0.99**(itr // num_batches - wait_until_kl_inc))

        batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
        train_res = model.compute_all_losses(
            batch_dict, n_traj_samples=3, kl_coef=kl_coef
        )  # for each elem in a batch, sample 3 trajectories from the same encoded posterior q(z|x)
        train_res["loss"].backward()
        optimizer.step()

        n_iters_to_viz = 1
        if itr % (n_iters_to_viz * num_batches) == 0:
            with torch.no_grad():

                test_res = compute_loss_all_batches(
                    model,
                    data_obj["test_dataloader"],
                    args,
                    n_batches=data_obj["n_test_batches"],
Esempio n. 3
0
def _kl_annealing_coef(epoch, wait_until_kl_inc=10):
    """Return the KL-term weight for a 0-based training epoch.

    For the first ``wait_until_kl_inc`` epochs a constant 0.01 is used;
    afterwards the weight is ``1 - 0.99**k`` with ``k`` the number of epochs
    since the warm-up ended.
    """
    # NOTE(review): at k == 0 the weight is 0, i.e. it drops from 0.01 to 0
    # when the warm-up ends before growing towards 1 -- confirm intended.
    if epoch < wait_until_kl_inc:
        return 0.01
    return 1 - 0.99**(epoch - wait_until_kl_inc)


def _classification_metrics(y_true, y_pred, split):
    """Compute sklearn classification metrics keyed for wandb logging.

    Args:
        y_true: ground-truth class indices.
        y_pred: predicted class indices.
        split: split name inserted into the keys ("train" / "validation").

    Returns:
        dict mapping 'Other_metrics/<split>_{cm,precision,recall,f1,kappa}'
        to the corresponding metric value.
    """
    prefix = 'Other_metrics/' + split
    return {
        prefix + '_cm': sklearn_cm(y_true, y_pred),
        prefix + '_precision': precision_score(y_true, y_pred, average='macro'),
        prefix + '_recall': recall_score(y_true, y_pred, average='macro'),
        prefix + '_f1': f1_score(y_true, y_pred, average='macro'),
        prefix + '_kappa': cohen_kappa_score(y_true, y_pred),
    }


def train_it(
        Model,
        Data_obj,
        args,
        file_name,
        ExperimentID,
        Validationwriter,
        input_command,
        Devices):
    """Train one model per device, periodically validating, checkpointing and logging.

    Parameters:
        Model: list of models, one per entry of ``Devices``.
        Data_obj: list of data objects (one per device); each is indexed with
            "n_train_batches", "dataset_obj", "train_dataloader",
            "test_dataloader" and "n_test_batches".
        args: parsed arguments; reads save, optimizer, lr, lrdecay, v,
            niters, val_freq and batch_size.
        file_name: base name of the log file written under ``logs/``.
        ExperimentID: list of experiment IDs (one per device), used in the
            checkpoint and log file names.
        Validationwriter: list of TensorBoard writers, one per device.
        input_command: string written to each device's logger.
        Devices: list of devices.

    Returns:
        (train_res, test_res, best_acc, best_acc_step): the last per-device
        training/validation results, and the best validation accuracy of the
        first device together with the step (iteration * batch_size) at
        which it was reached.
    """
    num_gpus = len(Devices)

    Ckpt_path = []
    Top_ckpt_path = []
    Best_test_acc = [0] * num_gpus
    Best_test_acc_step = [0] * num_gpus
    loggers = []
    Optimizer = []

    if not os.path.exists("logs/"):
        utils.makedirs("logs/")

    for i in range(num_gpus):
        experiment = str(ExperimentID[i])
        Ckpt_path.append(
            os.path.join(args.save, "experiment_" + experiment + '.ckpt'))
        Top_ckpt_path.append(
            os.path.join(args.save,
                         "experiment_" + experiment + '_topscore.ckpt'))

        log_path = "logs/" + file_name + "_" + experiment + ".log"
        logger = utils.get_logger(logpath=log_path,
                                  filepath=os.path.abspath(__file__))
        logger.info(input_command)
        loggers.append(logger)

        Optimizer.append(
            get_optimizer(args.optimizer, args.lr, Model[i].parameters()))

    num_batches = Data_obj[0]["n_train_batches"]
    labels = Data_obj[0]["dataset_obj"].label_list

    # Per-device result holders.
    train_res = [None] * num_gpus
    batch_dict = [None] * num_gpus
    label_dict = [None] * num_gpus
    # One placeholder per device (previously a single shared entry, which made
    # `test_res[i] = ...` fail for every device after the first). The
    # placeholder lets the progress bar print before the first validation.
    test_res = [{"accuracy": 0.0} for _ in range(num_gpus)]

    iterations = range(1, num_batches * args.niters + 1)
    if args.v in (1, 2):
        pbar = tqdm(iterations, position=0, leave=True, ncols=160)
    else:
        pbar = iterations

    for itr in pbar:

        for opt in Optimizer:
            opt.zero_grad()
            # default decay_rate = 0.999, lowest = args.lr/10   (original)
            # decay_rate = 0.9995, lowest = args.lr/50          (new)
            utils.update_learning_rate(opt,
                                       decay_rate=args.lrdecay,
                                       lowest=args.lr / 1000)

        kl_coef = _kl_annealing_coef(itr // num_batches)

        for i in range(num_gpus):
            batch_dict[i] = utils.get_next_batch(
                Data_obj[i]["train_dataloader"])

        for i in range(num_gpus):
            train_res[i] = Model[i].compute_all_losses(batch_dict[i],
                                                       n_traj_samples=3,
                                                       kl_coef=kl_coef)

        for i in range(num_gpus):
            train_res[i]["loss"].backward()

        for opt in Optimizer:
            opt.step()

        if itr % args.val_freq == 0:
            step = itr * args.batch_size
            with torch.no_grad():

                # Calculate labels and loss on test data
                for i in range(num_gpus):
                    test_res[i], label_dict[i] = compute_loss_all_batches(
                        Model[i],
                        Data_obj[i]["test_dataloader"],
                        args,
                        n_batches=Data_obj[i]["n_test_batches"],
                        experimentID=ExperimentID[i],
                        device=Devices[i],
                        n_traj_samples=3,
                        kl_coef=kl_coef)

                for i in range(num_gpus):
                    correct_labels = label_dict[i]["correct_labels"]
                    predict_labels = label_dict[i]["predict_labels"]

                    # Confusion matrix of this device's validation predictions.
                    cm, conf_fig = plot_confusion_matrix(correct_labels,
                                                         predict_labels,
                                                         labels,
                                                         tensor_name='dev/cm')
                    Validationwriter[i].add_figure(
                        "Validation_Confusionmatrix", conf_fig, step)

                    # Ground truth (ref) and predictions. For training these
                    # come from the last training batch only; ref is built
                    # from the labels and pred from the model output (the two
                    # were previously swapped, exchanging precision/recall).
                    y_ref_train = torch.argmax(batch_dict[i]['labels'],
                                               dim=1).cpu()
                    y_pred_train = torch.argmax(
                        train_res[i]['label_predictions'],
                        dim=2).squeeze().cpu()
                    y_ref = correct_labels.cpu()
                    y_pred = predict_labels

                    # Checkpoint every validation, plus a separate
                    # checkpoint whenever validation accuracy improves.
                    torch.save(
                        {
                            'args': args,
                            'state_dict': Model[i].state_dict(),
                        }, Ckpt_path[i])

                    if test_res[i]["accuracy"] > Best_test_acc[i]:
                        Best_test_acc[i] = test_res[i]["accuracy"]
                        Best_test_acc_step[i] = step
                        torch.save(
                            {
                                'args': args,
                                'state_dict': Model[i].state_dict(),
                                'cm': cm
                            }, Top_ckpt_path[i])

                    # NOTE: latent-trajectory (PCA) visualization was already
                    # disabled here; re-add it if needed.
                    logdict = {
                        'Classification_accuracy/train':
                        train_res[i]["accuracy"],
                        'Classification_accuracy/validation':
                        test_res[i]["accuracy"],
                        'Classification_accuracy/validation_peak':
                        Best_test_acc[i],
                        'Classification_accuracy/validation_peak_step':
                        Best_test_acc_step[i],
                        'loss/train':
                        train_res[i]["loss"].detach(),
                        'loss/validation':
                        test_res[i]["loss"].detach(),
                    }
                    logdict.update(
                        _classification_metrics(y_ref_train, y_pred_train,
                                                'train'))
                    logdict.update(
                        _classification_metrics(y_ref, y_pred, 'validation'))

                    # NOTE(review): with several devices every wandb.log call
                    # uses the same keys and step -- consider per-device keys.
                    wandb.log(logdict, step=step)

        # Write training loss and accuracy after every batch (only recommended
        # for debugging).
        fine_train_writer = False
        if fine_train_writer:
            for i in range(num_gpus):
                if "loss" in train_res[i]:
                    Validationwriter[i].add_scalar(
                        'loss/train', train_res[i]["loss"].detach(),
                        itr * args.batch_size)
                if "accuracy" in train_res[i]:
                    Validationwriter[i].add_scalar(
                        'Classification_accuracy/train',
                        train_res[i]["accuracy"], itr * args.batch_size)

        # Update progress bar (reports the first device only).
        if args.v == 2:
            pbar.set_description(
                "Train Ac: {:.3f} %  |  Test Ac: {:.3f} %, Peak Test Ac.: {:.3f} % (at {} batches)  |"
                .format(train_res[0]["accuracy"] * 100,
                        test_res[0]["accuracy"] * 100, Best_test_acc[0] * 100,
                        Best_test_acc_step[0] // args.batch_size))

        # Release references to this iteration's batches / labels.
        batch_dict = [None] * num_gpus
        label_dict = [None] * num_gpus

    print(Best_test_acc[0], " at step ", Best_test_acc_step[0])
    return train_res, test_res, Best_test_acc[0], Best_test_acc_step[0]