Example #1
class CometExperimentLogger(ExperimentLogger):
    def __init__(self, exp_name, online=True, **kwargs):
        super(CometExperimentLogger, self).__init__(exp_name, **kwargs)
        if online:
            self.comet = Experiment(project_name=exp_name, **kwargs)
        else:
            self.comet = OfflineExperiment(project_name=exp_name, **kwargs)

    def log_metric(self, tag, value, step, **kwargs):
        self.comet.log_metric(tag, value, step=step, **kwargs)

    def log_image(self, tag, img, step, **kwargs):
        self.comet.log_image(img, name=tag, step=step, **kwargs)

    def log_plt(self, tag, plt, step, **kwargs):
        self.comet.log_figure(figure=plt, figure_name=tag, step=step, **kwargs)

    def log_text(self, tag, text, **kwargs):
        self.comet.log_text(text, **kwargs)

    def log_parameters(self, params, **kwargs):
        self.comet.log_parameters(params, **kwargs)

    def start_epoch(self, **kwargs):
        super(CometExperimentLogger, self).start_epoch()

    def end_epoch(self, **kwargs):
        super(CometExperimentLogger, self).end_epoch()
        self.comet.log_epoch_end(self.epoch, **kwargs)

    def end_experiment(self):
        self.comet.end()
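
A minimal usage sketch for the wrapper above. It assumes the (not shown) ExperimentLogger base class only needs the experiment name, that start_epoch/end_epoch keep self.epoch up to date, and that Comet credentials come from COMET_API_KEY or a .comet.config file; the names and values are placeholders.

logger = CometExperimentLogger("demo-project", online=True)
logger.log_parameters({"lr": 1e-3, "batch_size": 32})
for epoch in range(3):
    logger.start_epoch()
    logger.log_metric("train/loss", 1.0 / (epoch + 1), step=epoch)
    logger.end_epoch()
logger.end_experiment()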
Example #2
def _set_comet_experiment(configuration, config_key):
    experiment = OfflineExperiment(
        project_name='general',
        workspace='benjaminbenoit',
        offline_directory="../damic_comet_experiences")
    experiment.set_name(config_key)
    experiment.log_parameters(configuration)
    return experiment
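
A short sketch of how this helper might be called; the configuration dictionary and the config key are placeholders.

configuration = {"learning_rate": 1e-3, "n_clusters": 17}
experiment = _set_comet_experiment(configuration, "damic_conv_autoencoder")
experiment.log_metric("train_loss", 0.42, step=0)
experiment.end()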
Example #3
        verbose = 10,
        n_jobs = 2,
        n_points = 2,
        scoring = 'accuracy',
    )

    checkpoint_callback = skopt.callbacks.CheckpointSaver(f'D:\\FINKI\\8_dps\\Project\\MODELS\\skopt_checkpoints\\{EXPERIMENT_ID}.pkl')
    hyperparameters_optimizer.fit(X_train, y_train, callback = [checkpoint_callback])
    skopt.dump(hyperparameters_optimizer, f'saved_models\\{EXPERIMENT_ID}.pkl')

    y_pred = hyperparameters_optimizer.best_estimator_.predict(X_test)

    for i in range(len(hyperparameters_optimizer.cv_results_['params'])):
        exp = OfflineExperiment(
            api_key = 'A8Lg71j9LtIrsv0deBA0DVGcR',
            project_name = ALGORITHM,
            workspace = "8_dps",
            auto_output_logging = 'native',
            offline_directory = f'D:\\FINKI\\8_dps\\Project\\MODELS\\comet_ml_offline_experiments\\{EXPERIMENT_ID}'
        )
        exp.set_name(f'{EXPERIMENT_ID}_{i + 1}')
        exp.add_tags([DS, SEGMENTS_LENGTH, ])
        for k, v in hyperparameters_optimizer.cv_results_.items():
            if k == "params": exp.log_parameters(dict(v[i]))
            else: exp.log_metric(k, v[i])
        exp.end()

        
        
        
Example #4
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=hp.num_workers)

    speaker_encoder = None
    if hp.speaker_encoder_path != "":
        speaker_encoder = load_speaker_encoder(Path(hp.speaker_encoder_path),
                                               device).to(device)
        for param in speaker_encoder.parameters():
            param.requires_grad = False
        else:
            speaker_encoder.train()

    # Define model
    fastspeech_model = FastSpeech2(speaker_encoder).to(device)
    model = nn.DataParallel(fastspeech_model).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan()
        vocoder_infer = utils.melgan_infer
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
        vocoder_infer = utils.waveglow_infer
    else:
        raise ValueError("Vocoder '%s' is not supported" % hp.vocoder)

    comet_experiment = None
    use_comet = int(os.getenv("USE_COMET", default=0))
    if use_comet != 0:
        if use_comet == 1:
            offline_dir = os.path.join(hp.models_path, "comet")
            os.makedirs(offline_dir, exist_ok=True)
            comet_experiment = OfflineExperiment(
                project_name="mlp-project",
                workspace="ino-voice",
                offline_directory=offline_dir,
            )
        elif use_comet == 2:
            comet_experiment = Experiment(
                api_key="BtyTwUoagGMh3uN4VZt6gMOn8",
                project_name="mlp-project",
                workspace="ino-voice",
            )

        comet_experiment.set_name(args.experiment_name)
        comet_experiment.log_parameters(hp)
        comet_experiment.log_html(args.m)

    start_time = time.perf_counter()
    first_mel_train_loss, first_postnet_train_loss, first_d_train_loss, first_f_train_loss, first_e_train_loss = \
        None, None, None, None, None

    for epoch in range(hp.epochs):
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                model = model.train()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # text = torch.from_numpy(data_of_batch["text"]).long()
                # mel_target = torch.from_numpy(data_of_batch["mel_target"]).float()
                # D = torch.from_numpy(data_of_batch["D"]).long()
                # log_D = torch.from_numpy(data_of_batch["log_D"]).float()
                # f0 = torch.from_numpy(data_of_batch["f0"]).float()
                # energy = torch.from_numpy(data_of_batch["energy"]).float()
                # src_len = torch.from_numpy(data_of_batch["src_len"]).long()
                # mel_len = torch.from_numpy(data_of_batch["mel_len"]).long()
                # max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                # max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = \
                    model(text, src_len, mel_target, mel_len, D, f0, energy, max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Set initial values for scaling (store plain floats so the
                # first batch's autograd graph is not kept alive)
                if first_mel_train_loss is None:
                    first_mel_train_loss = mel_loss.item()
                    first_postnet_train_loss = mel_postnet_loss.item()
                    first_d_train_loss = d_loss.item()
                    first_f_train_loss = f_loss.item()
                    first_e_train_loss = e_loss.item()

                mel_l = mel_loss.item() / first_mel_train_loss
                mel_postnet_l = mel_postnet_loss.item() / first_postnet_train_loss
                d_l = d_loss.item() / first_d_train_loss
                f_l = f_loss.item() / first_f_train_loss
                e_l = e_loss.item() / first_e_train_loss

                # Logger
                if comet_experiment is not None:
                    comet_experiment.log_metric(
                        "total_loss", mel_l + mel_postnet_l + d_l + f_l + e_l,
                        current_step)
                    comet_experiment.log_metric("mel_loss", mel_l,
                                                current_step)
                    comet_experiment.log_metric("mel_postnet_loss",
                                                mel_postnet_l, current_step)
                    comet_experiment.log_metric("duration_loss", d_l,
                                                current_step)
                    comet_experiment.log_metric("f0_loss", f_l, current_step)
                    comet_experiment.log_metric("energy_loss", e_l,
                                                current_step)

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    now = time.perf_counter()

                    print("\nEpoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step))
                    print(
                        "Total Loss: {:.4f}, Mel Loss: {:.5f}, Mel PostNet Loss: {:.5f}, Duration Loss: {:.5f}, "
                        "F0 Loss: {:.5f}, Energy Loss: {:.5f};".format(
                            mel_l + mel_postnet_l + d_l + f_l + e_l, mel_l,
                            mel_postnet_l, d_l, f_l, e_l))
                    print("Time Used: {:.3f}s".format(now - start_time))
                    start_time = now

                if current_step % hp.checkpoint == 0:
                    file_path = os.path.join(
                        checkpoint_path,
                        'checkpoint_{}.pth.tar'.format(current_step))
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, file_path)
                    print("saving model at to {}".format(file_path))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    if comet_experiment is not None:
                        comet_experiment.log_audio(
                            audiotools.inv_mel_spec(mel), hp.sampling_rate,
                            "step_{}_griffin_lim.wav".format(current_step))
                        comet_experiment.log_audio(
                            audiotools.inv_mel_spec(mel_postnet),
                            hp.sampling_rate,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step))
                        comet_experiment.log_audio(
                            vocoder_infer(mel_torch,
                                          vocoder), hp.sampling_rate,
                            'step_{}_{}.wav'.format(current_step, hp.vocoder))
                        comet_experiment.log_audio(
                            vocoder_infer(mel_postnet_torch, vocoder),
                            hp.sampling_rate, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder))
                        comet_experiment.log_audio(
                            vocoder_infer(mel_target_torch,
                                          vocoder), hp.sampling_rate,
                            'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder))

                        f0 = f0[0, :length].detach().cpu().numpy()
                        energy = energy[0, :length].detach().cpu().numpy()
                        f0_output = f0_output[
                            0, :length].detach().cpu().numpy()
                        energy_output = energy_output[
                            0, :length].detach().cpu().numpy()

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output, energy_output),
                             (mel_target.numpy(), f0, energy)],
                            comet_experiment, [
                                'Synthesized Spectrogram',
                                'Ground-Truth Spectrogram'
                            ])

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        if comet_experiment is not None:
                            with comet_experiment.validate():
                                d_l, f_l, e_l, m_l, m_p_l = evaluate(
                                    model, current_step, comet_experiment)
                                t_l = d_l + f_l + e_l + m_l + m_p_l

                                comet_experiment.log_metric(
                                    "total_loss", t_l, current_step)
                                comet_experiment.log_metric(
                                    "mel_loss", m_l, current_step)
                                comet_experiment.log_metric(
                                    "mel_postnet_loss", m_p_l, current_step)
                                comet_experiment.log_metric(
                                    "duration_loss", d_l, current_step)
                                comet_experiment.log_metric(
                                    "F0_loss", f_l, current_step)
                                comet_experiment.log_metric(
                                    "energy_loss", e_l, current_step)
Example #5
class AbstractTrainer(ABC):
    """Abstract class that fits the given model"""

    def __init__(
        self,
        model,
        optimizer,
        cfg,
        train_loader,
        valid_loader,
        test_loader,
        device,
        output_dir,
        hyper_params,
        max_patience=5,
    ):
        """
        :param model: pytorch model
        :param optimizer: pytorch optimizer
        :param cfg: config instance
        :param train_loader: train data loader
        :param valid_loader: valid data loader
        :param test_loader: test data loader
        :param device: gpu device used (ex: cuda:0)
        :param output_dir: output directory where the model and the results
            will be located
        :param hyper_params: hyper parameters
        :param max_patience: max number of iterations without seeing
            improvement in accuracy
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.cfg = cfg
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.device = device
        self.stats = StatsRecorder()
        self.performance_evaluator = PerformanceEvaluator(self.valid_loader)
        self.train_evaluator = PerformanceEvaluator(self.train_loader)
        self.valid_evaluator = PerformanceEvaluator(self.valid_loader)
        self.test_evaluator = PerformanceEvaluator(self.test_loader)
        self.output_dir = output_dir
        self.best_model = None
        self.hyper_params = hyper_params
        self.max_patience = max_patience
        self.current_patience = 0
        self.epoch = 0
        self.comet_ml_experiment = None
        self.last_checkpoint_filename = None

    def initialize_cometml_experiment(self, hyper_params):
        """
        Initialize the comet_ml experiment (only if enabled in config file)
        :param hyper_params: current hyper parameters dictionary
        :return:
        """
        if (
            self.comet_ml_experiment is None
            and self.cfg.COMET_ML_UPLOAD is True
        ):
            # Create an experiment
            self.comet_ml_experiment = Experiment(
                api_key=os.environ["COMET_API_KEY"],
                project_name="general",
                workspace="proguoram",
            )
            if self.comet_ml_experiment.disabled is True:
                # There is probably no internet (in the cluster, for example),
                # so we create an offline experiment instead
                self.comet_ml_experiment = OfflineExperiment(
                    workspace="proguoram",
                    project_name="general",
                    offline_directory=self.output_dir,
                )
            self.comet_ml_experiment.log_parameters(hyper_params)

    def fit(self, current_hyper_params, hyper_param_search_state=None):
        """
        Fit function that applies train, val and test to the model.
        Each train/val/test step may differ per model, but the I/O is the
        same, so we wrap them with this method and do the logging here.
        """
        self.initialize_cometml_experiment(current_hyper_params)
        print("# Start training #")
        since = time.time()

        summary_writer = TBSummaryWriter(self.output_dir, current_hyper_params)

        for epoch in range(self.epoch, self.cfg.TRAIN.NUM_EPOCHS, 1):
            self.train(current_hyper_params)
            self.train_evaluator.evaluate(
                self.model, self.device, self.stats, mode="train")
            self.validate(self.model)
            self.epoch = epoch
            print(
                "\nEpoch: {}/{}".format(epoch + 1, self.cfg.TRAIN.NUM_EPOCHS)
            )
            self.stats.print_last_epoch_stats()
            summary_writer.add_stats(self.stats, epoch)
            if self.cfg.COMET_ML_UPLOAD is True:
                self.stats.upload_to_comet_ml(self.comet_ml_experiment, epoch)
            if self.early_stopping_check(self.model, hyper_param_search_state):
                break
        time_elapsed = time.time() - since
        self.add_plots_summary(summary_writer)
        print(
            "\n\nTraining complete in {:.0f}m {:.0f}s".format(
                time_elapsed // 60, time_elapsed % 60
            )
        )
        # manually uncomment in training
        # comment this out in hyper_param_train
        print("Force save after final epoch")
        model_filename = self.output_dir + "/checkpoint_{}_ep_{}.pth".format(
            self.stats.valid_best_accuracy, self.epoch
        )
        self.save_current_best_model(model_filename, hyper_param_search_state)

    @abstractmethod
    def train(self, current_hyper_params):
        """
        Abstract method for the training
        :param current_hyper_params: current hyper parameters dictionary
        """
        pass

    def validate(self, model):
        """
        Validate the model
        :param model: pytorch model
        """
        self.valid_evaluator.evaluate(
            model, self.device, self.stats, mode="valid"
        )

    def test(self, model):
        """
        Test the model
        :param model: pytorch model
        """
        model = model.to(self.device)
        self.test_evaluator.evaluate(
            model, self.device, self.stats, mode="test"
        )

    def early_stopping_check(self, model, hyper_param_search_state=None):
        """
        Early stop check
        :param model: pytorch model
        :param hyper_param_search_state: hyper param search state (None by default)
        :return: True if need to stop. False if continue the training
        """
        last_accuracy_computed = self.stats.valid_accuracies[-1]
        if last_accuracy_computed > self.stats.valid_best_accuracy:
            self.stats.valid_best_accuracy = last_accuracy_computed
            self.best_model = copy.deepcopy(model)
            print("Checkpointing new model...")
            model_filename = self.output_dir + "/checkpoint_{}.pth".format(
                self.stats.valid_best_accuracy
            )
            self.save_current_best_model(
                model_filename, hyper_param_search_state
            )
            if self.last_checkpoint_filename is not None:
                os.remove(self.last_checkpoint_filename)
            self.last_checkpoint_filename = model_filename
            self.current_patience = 0
        else:
            self.current_patience += 1
            if self.current_patience > self.max_patience:
                return True
        return False

    def compute_loss(
        self, length_logits, digits_logits, length_labels, digits_labels
    ):
        """
        Multi loss computing function
        :param length_logits: length logits tensor (N x 7)
        :param digits_logits: digits logits tensor (N x 5 x 10)
        :param length_labels: length labels tensor (N x 1)
        :param digits_labels: digits labels tensor (N x 5 x 1)
        :return: loss tensor value
        """
        loss = torch.nn.functional.cross_entropy(length_logits, length_labels)
        for i in range(digits_labels.shape[1]):
            loss = loss + torch.nn.functional.cross_entropy(
                digits_logits[i], digits_labels[:, i], ignore_index=-1
            )
        return loss

    def load_state_dict(self, state_dict):
        """
        Loads the previous state of the trainer
        Should be overridden in the children classes if needed (see LRSchedulerTrainer for an example)
        :param state_dict: state dictionary
        """
        self.epoch = state_dict["epoch"]
        self.stats = state_dict["stats"]
        self.current_patience = state_dict["current_patience"]
        self.best_model = self.model

    def get_state_dict(self, hyper_param_search_state=None):
        """
         Gets the current state of the trainer
         Should be overriden in the children classes if needed (see LRSchedulerTrainer for an example)
         :param hyper_param_search_state: hyper param search state if we are doing an hyper params serach
         (None by default)
        :return state_dict
         """
        seed = np.random.get_state()[1][0]
        return {
            "epoch": self.epoch + 1,
            "model_state_dict": self.best_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "stats": self.stats,
            "seed": seed,
            "current_patience": self.current_patience,
            "hyper_param_search_state": hyper_param_search_state,
            "hyper_params": self.hyper_params,
        }

    def save_current_best_model(self, out_path, hyper_param_search_state=None):
        """
        Saves the current best model
        :param out_path: output path string
        :param hyper_param_search_state: hyper param search state if we are doing a hyper params search
            (None by default)
        """
        state_dict = self.get_state_dict(hyper_param_search_state)
        torch.save(state_dict, out_path)
        print("Model saved!")

    def _train_batch(self, batch):
        """
        Basic batch training step (normally called by child classes)
        :param batch: batch
        :return: loss tensor value
        """
        inputs, targets = batch["image"], batch["target"]

        inputs = inputs.to(self.device)
        targets = targets.long().to(self.device)
        target_ndigits = targets[:, 0]

        # Zero the gradient buffer
        self.optimizer.zero_grad()

        # Forward
        pred_length, pred_sequences = self.model(inputs)

        # For each digit predicted
        target_digit = targets[:, 1:]
        loss = self.compute_loss(
            pred_length, pred_sequences, target_ndigits, target_digit
        )
        # Backward
        loss.backward()
        # Optimize
        self.optimizer.step()

        return loss

    def add_plots_summary(self, summary_writer):
        """
        Add plotting values for tensor board
        :param summary_writer: Summary writer object from tensor board
        """
        # plot loss curves
        loss_dict = {
            "Train loss": self.stats.train_loss_history,
            "Valid loss": self.stats.valid_losses,
        }
        axis_labels = {"x": "Epochs", "y": "Loss"}
        summary_writer.plot_curves(loss_dict, "Learning curves", axis_labels)

        # plot accuracy curves
        acc_dict = {
            "Valid accuracy": self.stats.valid_accuracies,
            "Length accuracy": self.stats.length_accuracy,
        }
        axis_labels = {"x": "Epochs", "y": "Accuracy"}
        summary_writer.plot_curves(acc_dict, "Accuracy curves", axis_labels)
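
The online-then-offline fallback in initialize_cometml_experiment can also be written as a standalone helper. A sketch, assuming COMET_API_KEY is set (as in the class above) and using placeholder workspace/project names; the disabled check mirrors the one used by the trainer.

import os
from comet_ml import Experiment, OfflineExperiment

def create_experiment_with_fallback(output_dir, workspace="my-workspace", project_name="general"):
    """Try an online Experiment first; fall back to an OfflineExperiment if it is disabled."""
    experiment = Experiment(api_key=os.environ["COMET_API_KEY"],
                            project_name=project_name,
                            workspace=workspace)
    if experiment.disabled:
        # Probably no internet access (e.g. on a cluster node): log offline instead
        experiment = OfflineExperiment(workspace=workspace,
                                       project_name=project_name,
                                       offline_directory=output_dir)
    return experiment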
Example #6
class Logger:
    def __init__(self, send_logs, tags, parameters, experiment=None):
        self.stations = 5
        self.send_logs = send_logs
        if self.send_logs:
            if experiment is None:
                json_loc = glob.glob("./**/comet_token.json")[0]
                with open(json_loc, "r") as f:
                    kwargs = json.load(f)

                self.experiment = OfflineExperiment(**kwargs)
            else:
                self.experiment = experiment
        self.sent_mb = 0
        self.speed_window = deque(maxlen=100)
        self.step_time = None
        self.current_speed = 0
        if self.send_logs:
            if tags is not None:
                self.experiment.add_tags(tags)
            if parameters is not None:
                self.experiment.log_parameters(parameters)

    def begin_logging(self, episode_count, steps_per_ep, sigma, theta, step_time):
        self.step_time = step_time
        if self.send_logs:
            self.experiment.log_parameter("Episode count", episode_count)
            self.experiment.log_parameter("Steps per episode", steps_per_ep)
            self.experiment.log_parameter("theta", theta)
            self.experiment.log_parameter("sigma", sigma)

    def log_round(self, states, reward, cumulative_reward, info, loss, observations, step):
        if self.send_logs:
            self.experiment.log_histogram_3d(states, name="Observations", step=step)
        info = [[j for j in i.split("|")] for i in info]
        info = np.mean(np.array(info, dtype=np.float32), axis=0)
        try:
            round_mb = info[0]
        except Exception as e:
            print(info)
            print(reward)
            raise e
        self.speed_window.append(round_mb)
        self.current_speed = np.mean(np.asarray(self.speed_window)/self.step_time)
        self.sent_mb += round_mb
        CW = info[1]
        CW_ax = info[2]
        self.stations = info[3]
        fairness = info[4]

        if self.send_logs:
            self.experiment.log_metric("Round reward", np.mean(reward), step=step)
            self.experiment.log_metric("Per-ep reward", np.mean(cumulative_reward), step=step)
            self.experiment.log_metric("Megabytes sent", self.sent_mb, step=step)
            self.experiment.log_metric("Round megabytes sent", round_mb, step=step)
            self.experiment.log_metric("Chosen CW for legacy devices", CW, step=step)
            self.experiment.log_metric("Chosen CW for 802.11ax devices", CW_ax, step=step)
            self.experiment.log_metric("Station count", self.stations, step=step)
            self.experiment.log_metric("Current throughput", self.current_speed, step=step)
            self.experiment.log_metric("Fairness index", fairness, step=step)

            for i, obs in enumerate(observations):
                self.experiment.log_metric(f"Observation {i}", obs, step=step)

            self.experiment.log_metrics(loss, step=step)

    def log_episode(self, cumulative_reward, speed, step):
        if self.send_logs:
            self.experiment.log_metric("Cumulative reward", cumulative_reward, step=step)
            self.experiment.log_metric("Speed", speed, step=step)

        self.sent_mb = 0
        self.last_speed = speed
        self.speed_window = deque(maxlen=100)
        self.current_speed = 0

    def end(self):
        if self.send_logs:
            self.experiment.end()
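
A sketch of how the Logger above might be driven. It assumes a comet_token.json file exists below the working directory and contains standard OfflineExperiment keyword arguments (for example project_name, workspace and offline_directory); the tags, parameters and metric values are placeholders.

logger = Logger(send_logs=True,
                tags=["ddpg", "cw-optimization"],
                parameters={"learning_rate": 4e-4})
logger.begin_logging(episode_count=15, steps_per_ep=600,
                     sigma=0.5, theta=0.15, step_time=0.01)
# ... training rounds would call logger.log_round(...) here ...
logger.log_episode(cumulative_reward=123.4, speed=35.1, step=0)
logger.end()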
Example #7
                                       parse_args=False,
                                       project_name='swissroll-' + args.tag,
                                       workspace="wronnyhuang")
    else:
        experiment = Experiment(api_key="vPCPPZrcrUBitgoQkvzxdsh9k",
                                parse_args=False,
                                project_name='swissroll-' + args.tag,
                                workspace="wronnyhuang")

    open(join(logdir, 'comet_expt_key.txt'), 'w+').write(experiment.get_key())
    if any([a.find('nhidden1') != -1 for a in sys.argv[1:]]):
        args.nhidden = [
            args.nhidden1, args.nhidden2, args.nhidden3, args.nhidden4,
            args.nhidden5, args.nhidden6
        ]
    experiment.log_parameters(vars(args))
    experiment.set_name(args.sugg)
    print(sys.argv)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # make dataset
    X, y = twospirals(args.ndata // 2, noise=args.noise)
    order = np.random.permutation(len(X))
    X = X[order]
    y = y[order]
    splitIdx = int(.5 * len(X))
    xtrain, ytrain = X[:splitIdx], y[:splitIdx, None]
    xtest, ytest = X[splitIdx:], y[splitIdx:, None]
Example #8
            print(e)
            print("Ignoring argument", u)

    for o in dir(opts):
        if not o.startswith("_"):
            if o in config:
                print("Overwriting {:20} {:30} -> {:}".format(
                    o, config[o], getattr(opts, o)))
                config[o] = getattr(opts, o)

comet_exp.log_asset(opts.config)
max_iter = config["max_iter"]
display_size = config["display_size"]
config["vgg_model_path"] = opts.output_path

comet_exp.log_parameters(config)

print("Using model", opts.trainer)
# Setup model and data loader
if opts.trainer == "MUNIT":
    trainer = MUNIT_Trainer(config, comet_exp)
elif opts.trainer == "UNIT":
    trainer = UNIT_Trainer(config)
elif opts.trainer == "DoubleMUNIT":
    trainer = DoubleMUNIT_Trainer(config, comet_exp)
else:
    sys.exit("Only support MUNIT|UNIT|DoubleMUNIT")
trainer.cuda()
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
    config)
train_display_images_a = torch.stack(
Example #9
    # Set all seeds for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Set up Comet Experiment tracking
    experiment = OfflineExperiment("z15Um8oxWZwiXQXZxZKGh48cl",
                                   workspace='swechhachoudhary',
                                   offline_directory="../swechhas_experiments")

    experiment.set_name(
        name=args.config +
        "_dim={}_overlapped={}".format(latent_dim, train_split))
    experiment.log_parameters(configuration)

    if encoding_model == 'pca':
        encoding_model = PCAEncoder(seed)
        flattened = True
    elif encoding_model == 'vae':
        encoding_model = VAE(latent_dim=latent_dim).to(device)
        flattened = True
    elif encoding_model == "ae":
        encoding_model = AE(latent_dim=latent_dim).to(device)
        flattened = True
    elif encoding_model == "cae":
        encoding_model = CAE(latent_dim=latent_dim).to(device)
        flattened = False
    elif encoding_model == "cvae":
        encoding_model = CVAE(latent_dim=latent_dim).to(device)
Example #10
class CometWriter:
    def __init__(self,
                 logger,
                 project_name: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 api_key: Optional[str] = None,
                 log_dir: Optional[str] = None,
                 offline: bool = False,
                 **kwargs):
        if not _COMET_AVAILABLE:
            raise ImportError(
                "You want to use `comet_ml` logger which is not installed yet,"
                " install it with `pip install comet-ml`.")

        self.project_name = project_name
        self.experiment_name = experiment_name
        self.kwargs = kwargs

        self.timer = Timer()

        if (api_key is not None) and (log_dir is not None):
            self.mode = "offline" if offline else "online"
            self.api_key = api_key
            self.log_dir = log_dir

        elif api_key is not None:
            self.mode = "online"
            self.api_key = api_key
            self.log_dir = None
        elif log_dir is not None:
            self.mode = "offline"
            self.log_dir = log_dir
        else:
            logger.warning(
                "CometLogger requires either api_key or save_dir during initialization."
            )

        if self.mode == "online":
            self.experiment = CometExperiment(
                api_key=self.api_key,
                project_name=self.project_name,
                **self.kwargs,
            )
        else:
            self.experiment = CometOfflineExperiment(
                offline_directory=self.log_dir,
                project_name=self.project_name,
                **self.kwargs,
            )

        if self.experiment_name:
            self.experiment.set_name(self.experiment_name)

    def set_step(self, step, epoch=None, mode='train') -> None:
        self.mode = mode
        self.step = step
        self.epoch = epoch
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            self.add_scalar({'steps_per_sec': 1 / duration})

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        self.experiment.log_parameters(params)

    def log_code(self, file_name=None, folder='models/') -> None:
        self.experiment.log_code(file_name=file_name, folder=folder)

    def add_scalar(self,
                   metrics: Dict[str, Union[torch.Tensor, float]],
                   step: Optional[int] = None,
                   epoch: Optional[int] = None) -> None:
        metrics_renamed = {}
        for key, val in metrics.items():
            tag = '{}/{}'.format(key, self.mode)
            if is_tensor(val):
                metrics_renamed[tag] = val.cpu().detach()
            else:
                metrics_renamed[tag] = val
        if epoch is None:
            self.experiment.log_metrics(metrics_renamed,
                                        step=self.step,
                                        epoch=self.epoch)
        else:
            self.experiment.log_metrics(metrics_renamed, epoch=epoch)

    def add_plot(self, figure_name, figure):
        """
        Primarily for log gate plots
        """
        self.experiment.log_figure(figure_name=figure_name, figure=figure)

    def add_hist3d(self, hist, name):
        """
        Primarily for log gate plots
        """
        self.experiment.log_histogram_3d(hist, name=name)

    def reset_experiment(self):
        self.experiment = None

    def finalize(self) -> None:
        self.experiment.end()
        self.reset_experiment()
Example #11
        comet_exp = OfflineExperiment(
            api_key="hIXq6lDzWzz24zgKv7RYz6blo",
            project_name="supercyclecons",
            workspace="cinjon",
            auto_metric_logging=True,
            auto_output_logging=None,
            auto_param_logging=False,
            offline_directory=params['local_comet_dir'])
    else:
        comet_exp = CometExperiment(api_key="hIXq6lDzWzz24zgKv7RYz6blo",
                                    project_name="supercyclecons",
                                    workspace="cinjon",
                                    auto_metric_logging=True,
                                    auto_output_logging=None,
                                    auto_param_logging=False)
    comet_exp.log_parameters(vars(args))
    comet_exp.set_name(params['name'])


def partial_load(pretrained_dict, model):
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
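
partial_load above copies only the overlapping keys from a checkpoint into a freshly built model. A minimal usage sketch; the checkpoint path and the resnet18 choice are placeholders.

import torch
import torchvision

model = torchvision.models.resnet18()
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
partial_load(checkpoint["state_dict"], model)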
Example #12
if __name__ == '__main__':
    SEED = 1234

    random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    config = get_config(sys.argv[1])
    # experiment = Experiment("wXwnV8LZOtVfxqnRxr65Lv7C2")
    comet_dir_path = os.path.join(config["result_directory"], config["model"])
    makedirs(comet_dir_path)
    experiment = OfflineExperiment(
        project_name="DeepGenomics",
        offline_directory=comet_dir_path)
    experiment.log_parameters(config)
    if torch.cuda.is_available():
        # torch.cuda.set_device(str(os.environ["CUDA_VISIBLE_DEVICES"]))
        device = torch.device('cuda:{}'.format(os.environ["CUDA_VISIBLE_DEVICES"]))
    else:
        device = torch.device('cpu')
    print(device)
    number_of_examples = len(get_filenames(os.path.join(config["data"], "x")))
    list_ids = [str(i) for i in range(number_of_examples)]
    random.shuffle(list_ids)
    t_ind, v_ind = round(number_of_examples * 0.7), round(number_of_examples * 0.9)
    train_indices, validation_indices, test_indices = list_ids[:t_ind], list_ids[t_ind:v_ind], list_ids[v_ind:]
    
    params = {'batch_size': config["training"]["batch_size"],
              'shuffle': config["training"]["shuffle"],
              'num_workers': config["training"]["num_workers"]}
Example #13
class CometLogger(Logger):
    def __init__(
        self,
        batch_size: int,
        snapshot_dir: Optional[str] = None,
        snapshot_mode: str = "last",
        snapshot_gap: int = 1,
        exp_set: Optional[str] = None,
        use_print_exp: bool = False,
        saved_exp: Optional[str] = None,
        **kwargs,
    ):
        """
        :param kwargs: passed to comet's Experiment at init.
        """
        if use_print_exp:
            self.experiment = PrintExperiment()
        else:
            from comet_ml import Experiment, ExistingExperiment, OfflineExperiment

            if saved_exp:
                self.experiment = ExistingExperiment(
                    previous_experiment=saved_exp, **kwargs
                )
            else:
                try:
                    self.experiment = Experiment(**kwargs)
                except ValueError:  # no API key
                    log_dir = Path.home() / "logs"
                    log_dir.mkdir(exist_ok=True)
                    self.experiment = OfflineExperiment(offline_directory=str(log_dir))

        self.experiment.log_parameter("complete", False)
        if exp_set:
            self.experiment.log_parameter("exp_set", exp_set)
        if snapshot_dir:
            snapshot_dir = Path(snapshot_dir) / self.experiment.get_key()
        # log_traj_window (int): How many trajectories to hold in deque for computing performance statistics.
        self.log_traj_window = 100
        self._cum_metrics = {
            "n_unsafe_actions": 0,
            "constraint_used": 0,
            "cum_completed_trajs": 0,
            "logging_time": 0,
        }
        self._new_completed_trajs = 0
        self._last_step = 0
        self._start_time = self._last_time = time()
        self._last_snapshot_upload = 0
        self._snapshot_upload_time = 30 * 60

        super().__init__(batch_size, snapshot_dir, snapshot_mode, snapshot_gap)

    def log_fast(
        self,
        step: int,
        traj_infos: Sequence[Dict[str, float]],
        opt_info: Optional[Tuple[Sequence[float], ...]] = None,
        test: bool = False,
    ) -> None:
        if not traj_infos:
            return
        start = time()

        self._new_completed_trajs += len(traj_infos)
        self._cum_metrics["cum_completed_trajs"] += len(traj_infos)
        # TODO: do we need to support sum(t[k]) if key in k?
        # without that, this doesn't include anything from extra eval samplers
        for key in self._cum_metrics:
            if key == "cum_completed_trajs":
                continue
            self._cum_metrics[key] += sum(t.get(key, 0) for t in traj_infos)
        self._cum_metrics["logging_time"] += time() - start

    def log(
        self,
        step: int,
        traj_infos: Sequence[Dict[str, float]],
        opt_info: Optional[Tuple[Sequence[float], ...]] = None,
        test: bool = False,
    ):
        self.log_fast(step, traj_infos, opt_info, test)
        start = time()
        with (self.experiment.test() if test else nullcontext()):
            step *= self.batch_size
            if opt_info is not None:
                # grad norm is left on the GPU for some reason
                # https://github.com/astooke/rlpyt/issues/163
                self.experiment.log_metrics(
                    {
                        k: np.mean(v)
                        for k, v in opt_info._asdict().items()
                        if k != "gradNorm"
                    },
                    step=step,
                )

            if traj_infos:
                agg_vals = {}
                for key in traj_infos[0].keys():
                    if key in self._cum_metrics:
                        continue
                    agg_vals[key] = sum(t[key] for t in traj_infos) / len(traj_infos)
                self.experiment.log_metrics(agg_vals, step=step)

            if not test:
                now = time()
                self.experiment.log_metrics(
                    {
                        "new_completed_trajs": self._new_completed_trajs,
                        "steps_per_second": (step - self._last_step)
                        / (now - self._last_time),
                    },
                    step=step,
                )
                self._last_time = now
                self._last_step = step
                self._new_completed_trajs = 0

        self.experiment.log_metrics(self._cum_metrics, step=step)
        self._cum_metrics["logging_time"] += time() - start

    def log_metric(self, name, val):
        self.experiment.log_metric(name, val)

    def log_parameters(self, parameters):
        self.experiment.log_parameters(parameters)

    def log_config(self, config):
        self.experiment.log_parameter("config", json.dumps(convert_dict(config)))

    def upload_snapshot(self):
        if self.snapshot_dir:
            self.experiment.log_asset(self._previous_snapshot_fname)

    def save_itr_params(
        self, step: int, params: Dict[str, Any], metric: Optional[float] = None
    ) -> None:
        super().save_itr_params(step, params, metric)
        now = time()
        if now - self._last_snapshot_upload > self._snapshot_upload_time:
            self._last_snapshot_upload = now
            self.upload_snapshot()

    def shutdown(self, error: bool = False) -> None:
        if not error:
            self.upload_snapshot()
            self.experiment.log_parameter("complete", True)
        self.experiment.end()
Example #14
def run_experiment_iter(i, experiment, train_iter, nExp, agent_list, env,
                        video, user_seed, experiment_name, log_params, debug,
                        project_name, sps, sps_es, **kwargs):
    """
    Function used to parallelize the run_experiment calculations.

    Parameters
    ----------
    i : int
        Index of the agent being trained.

    Raises
    ------
    NotImplementedError
        In case Comet is used, raises this error to signal where user intervention
        is required (namely to set the api_key and the workspace).

    Returns
    -------
    rewards : array
        An array with the cumulative rewards, where each column corresponds to
        an agent (random seed), and each row to a training iteration.
    arms : array
        An array with the number of agent arms, where each column corresponds
        to an agent (random seed), and each row to a training iteration.
    agent : Agent
        The trained agent.

    """
    if debug:
        start = time.time()
        print("Experiment {0} out of {1}...".format(i + 1, nExp))
    if not user_seed:
        seed = int.from_bytes(os.urandom(4), 'big')
    else:
        seed = user_seed

    if experiment_name:
        raise NotImplementedError(
            "Before using Comet, you need to come here and set your API key")
        experiment = Experiment(api_key=None,
                                project_name=project_name,
                                workspace=None,
                                display_summary=False,
                                offline_directory="offline")
        experiment.add_tag(experiment_name)
        experiment.set_name("{0}_{1}".format(experiment_name, i))
        # Sometimes adding the tag fails
        log_params["experiment_tag"] = experiment_name
        experiment.log_parameters(log_params)

    agent = agent_list[i]
    if sps_es:  # This one overrides sps
        rewards, arms, agent = run_sps_es_experiment(agent,
                                                     env,
                                                     train_iter,
                                                     seed=seed,
                                                     video=video,
                                                     experiment=experiment,
                                                     **kwargs)
    elif sps:
        rewards, arms, agent = run_sps_experiment(agent,
                                                  env,
                                                  train_iter,
                                                  seed=seed,
                                                  video=video,
                                                  experiment=experiment,
                                                  **kwargs)
    else:
        rewards, arms, agent = run_aql_experiment(agent,
                                                  env,
                                                  train_iter,
                                                  seed=seed,
                                                  video=video,
                                                  experiment=experiment,
                                                  **kwargs)
    agent_list[i] = agent

    if experiment:
        experiment.end()

    if debug:
        end = time.time()
        elapsed = end - start
        units = "secs"
        if elapsed > 3600:
            elapsed /= 3600
            units = "hours"
        elif elapsed > 60:
            elapsed /= 60
            units = "mins"
        print("Time elapsed: {0:.02f} {1}".format(elapsed, units))

    return rewards, arms, agent
Example #15
    "sequence_length": 28,
    "input_size": 28,
    "hidden_size": 128,
    "num_layers": 2,
    "num_classes": 10,
    "batch_size": 100,
    "num_epochs": 3,
    "learning_rate": 0.01
}

optimizer = Optimizer("pA3Hqc1pEswNvXOPtSoRobt7C")

experiment = OfflineExperiment(project_name="horoma",
                               offline_directory="./experiments",
                               disabled=False)
experiment.log_parameters(hyper_params)

# MNIST Dataset
train_dataset = dsets.MNIST(root='./data/',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data/',
                           train=False,
                           transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=hyper_params['batch_size'], shuffle=True)
Example #16
def main(args):
    print('Pretrain? ', not args.not_pretrain)
    print(args.model)
    start_time = time.time()

    if opt['local_comet_dir']:
        comet_exp = OfflineExperiment(api_key="hIXq6lDzWzz24zgKv7RYz6blo",
                                      project_name="selfcifar",
                                      workspace="cinjon",
                                      auto_metric_logging=True,
                                      auto_output_logging=None,
                                      auto_param_logging=False,
                                      offline_directory=opt['local_comet_dir'])
    else:
        comet_exp = CometExperiment(api_key="hIXq6lDzWzz24zgKv7RYz6blo",
                                    project_name="selfcifar",
                                    workspace="cinjon",
                                    auto_metric_logging=True,
                                    auto_output_logging=None,
                                    auto_param_logging=False)
    comet_exp.log_parameters(vars(args))
    comet_exp.set_name(args.name)

    # Build model
    # path = "/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/bsn"
    linear_cls = NonLinearModel if args.do_nonlinear else LinearModel

    if args.model == "amdim":
        hparams = load_hparams_from_tags_csv(
            '/checkpoint/cinjon/amdim/meta_tags.csv')
        # hparams = load_hparams_from_tags_csv(os.path.join(path, "meta_tags.csv"))
        model = AMDIMModel(hparams)
        if not args.not_pretrain:
            # _path = os.path.join(path, "_ckpt_epoch_434.ckpt")
            _path = '/checkpoint/cinjon/amdim/_ckpt_epoch_434.ckpt'
            model.load_state_dict(torch.load(_path)["state_dict"])
        else:
            print("AMDIM not loading checkpoint")  # Debug
        linear_model = linear_cls(AMDIM_OUTPUT_DIM, args.num_classes)
    elif args.model == "ccc":
        model = CCCModel(None)
        if not args.not_pretrain:
            # _path = os.path.join(path, "TimeCycleCkpt14.pth")
            _path = '/checkpoint/cinjon/spaceofmotion/bsn/TimeCycleCkpt14.pth'
            checkpoint = torch.load(_path)
            base_dict = {
                '.'.join(k.split('.')[1:]): v
                for k, v in list(checkpoint['state_dict'].items())
            }
            model.load_state_dict(base_dict)
        else:
            print("CCC not loading checkpoint")  # Debug
        linear_model = linear_cls(CCC_OUTPUT_DIM,
                                  args.num_classes)  #.to(device)
    elif args.model == "corrflow":
        model = CORRFLOWModel(None)
        if not args.not_pretrain:
            _path = '/checkpoint/cinjon/spaceofmotion/supercons/corrflow.kineticsmodel.pth'
            # _path = os.path.join(path, "corrflow.kineticsmodel.pth")
            checkpoint = torch.load(_path)
            base_dict = {
                '.'.join(k.split('.')[1:]): v
                for k, v in list(checkpoint['state_dict'].items())
            }
            model.load_state_dict(base_dict)
        else:
            print("CorrFlow not loading checkpoing")  # Debug
        linear_model = linear_cls(CORRFLOW_OUTPUT_DIM, args.num_classes)
    elif args.model == "resnet":
        if not args.not_pretrain:
            resnet = torchvision.models.resnet50(pretrained=True)
        else:
            resnet = torchvision.models.resnet50(pretrained=False)
            print("ResNet not loading checkpoint")  # Debug
        modules = list(resnet.children())[:-1]
        model = nn.Sequential(*modules)
        linear_model = linear_cls(RESNET_OUTPUT_DIM, args.num_classes)
    else:
        raise Exception("model type has to be amdim, ccc, corrflow or resnet")

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model).to(device)
        linear_model = nn.DataParallel(linear_model).to(device)
    else:
        model = model.to(device)
        linear_model = linear_model.to(device)
    # model = model.to(device)
    # linear_model = linear_model.to(device)

    # Freeze model
    for p in model.parameters():
        p.requires_grad = False
    model.eval()

    if args.optimizer == "Adam":
        optimizer = optim.Adam(linear_model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
        print("Optimizer: Adam with weight decay: {}".format(
            args.weight_decay))
    elif args.optimizer == "SGD":
        optimizer = optim.SGD(linear_model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        print("Optimizer: SGD with weight decay: {} momentum: {}".format(
            args.weight_decay, args.momentum))
    else:
        raise Exception("optimizer should be Adam or SGD")
    optimizer.zero_grad()

    # Set up log dir
    now = datetime.datetime.now()
    log_dir = '/checkpoint/cinjon/spaceofmotion/bsn/cifar-%d-weights/%s/%s' % (
        args.num_classes, args.model, args.name)
    # log_dir = "{}{:%Y%m%dT%H%M}".format(args.model, now)
    # log_dir = os.path.join("weights", log_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    print("Saving to {}".format(log_dir))

    batch_size = args.batch_size * torch.cuda.device_count()
    # CIFAR-10
    if args.num_classes == 10:
        data_path = ("/private/home/cinjon/cifar-data/cifar-10-batches-py")
        _train_dataset = CIFAR_dataset(glob(os.path.join(data_path, "data*")),
                                       args.num_classes, args.model, True)
        # _train_acc_dataset = CIFAR_dataset(
        #     glob(os.path.join(data_path, "data*")),
        #     args.num_classes,
        #     args.model,
        #     False)
        train_dataloader = data.DataLoader(_train_dataset,
                                           shuffle=True,
                                           batch_size=batch_size,
                                           num_workers=args.num_workers)
        # train_split = int(len(_train_dataset) * 0.8)
        # train_dev_split = int(len(_train_dataset) - train_split)
        # train_dataset, train_dev_dataset = data.random_split(
        #     _train_dataset, [train_split, train_dev_split])
        # train_acc_dataloader = data.DataLoader(
        #     train_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
        # train_dev_acc_dataloader = data.DataLoader(
        #     train_dev_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
        # train_dataset = data.Subset(_train_dataset, list(range(train_split)))
        # train_dataloader = data.DataLoader(
        #     train_dataset, shuffle=True, batch_size=batch_size, num_workers=args.num_workers)
        # train_acc_dataset = data.Subset(
        #     _train_acc_dataset, list(range(train_split)))
        # train_acc_dataloader = data.DataLoader(
        #     train_acc_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
        # train_dev_acc_dataset = data.Subset(
        #     _train_acc_dataset, list(range(train_split, len(_train_acc_dataset))))
        # train_dev_acc_dataloader = data.DataLoader(
        #     train_dev_acc_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)

        _val_dataset = CIFAR_dataset([os.path.join(data_path, "test_batch")],
                                     args.num_classes, args.model, False)
        val_dataloader = data.DataLoader(_val_dataset,
                                         shuffle=False,
                                         batch_size=batch_size,
                                         num_workers=args.num_workers)
        # val_split = int(len(_val_dataset) * 0.8)
        # val_dev_split = int(len(_val_dataset) - val_split)
        # val_dataset, val_dev_dataset = data.random_split(
        #     _val_dataset, [val_split, val_dev_split])
        # val_dataloader = data.DataLoader(
        #     val_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
        # val_dev_dataloader = data.DataLoader(
        #     val_dev_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
    # CIFAR-100
    elif args.num_classes == 100:
        data_path = ("/private/home/cinjon/cifar-data/cifar-100-python")
        _train_dataset = CIFAR_dataset([os.path.join(data_path, "train")],
                                       args.num_classes, args.model, True)
        train_dataloader = data.DataLoader(_train_dataset,
                                           shuffle=True,
                                           batch_size=batch_size,
                                           num_workers=args.num_workers)
        _val_dataset = CIFAR_dataset([os.path.join(data_path, "test")],
                                     args.num_classes, args.model, False)
        val_dataloader = data.DataLoader(_val_dataset,
                                         shuffle=False,
                                         batch_size=batch_size,
                                         num_workers=args.num_workers)
    else:
        raise Exception("num_classes should be 10 or 100")

    best_acc = 0.0
    best_epoch = 0

    # Training
    for epoch in range(1, args.epochs + 1):
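        # Step decay: halve the base lr every args.lr_interval epochs, with a
        # floor of 3e-4; the optimizer is rebuilt below with the new rate.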
        current_lr = max(
            3e-4, args.lr * math.pow(0.5, math.floor(epoch / args.lr_interval)))
        linear_model.train()
        if args.optimizer == "Adam":
            optimizer = optim.Adam(linear_model.parameters(),
                                   lr=current_lr,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == "SGD":
            optimizer = optim.SGD(
                linear_model.parameters(),
                lr=current_lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
            )

        ####################################################
        # Train
        t = time.time()
        train_acc = 0
        train_loss_sum = 0.0
        for iter, input in enumerate(train_dataloader):
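            # Wall-clock guard: bail out ~5 minutes before the args.time-hour
            # budget so the Comet experiment can be closed cleanly.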
            if (time.time() - start_time > args.time * 3600 - 300
                    and comet_exp is not None):
                comet_exp.end()
                sys.exit(-1)

            imgs = input[0].to(device)
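            # Non-ResNet backbones expect an extra singleton dimension
            # (presumably a frame/clip axis for the video models).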
            if args.model != "resnet":
                imgs = imgs.unsqueeze(1)
            lbls = input[1].flatten().to(device)

            # output = model(imgs)
            # output = linear_model(output)
            output = linear_model(model(imgs))
            loss = F.cross_entropy(output, lbls)
            train_loss_sum += float(loss.data)
            train_acc += int(sum(torch.argmax(output, dim=1) == lbls))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log_text = "train epoch {}/{}\titer {}/{} loss:{} {:.3f}s/iter"
            if iter % 1500 == 0:
                log_text = "train epoch {}/{}\titer {}/{} loss:{}"
                print(log_text.format(epoch, args.epochs, iter + 1,
                                      len(train_dataloader), loss.data,
                                      time.time() - t),
                      flush=False)
                t = time.time()

        train_acc /= len(_train_dataset)
        train_loss_sum /= len(train_dataloader)
        if comet_exp is not None:
            with comet_exp.train():
                comet_exp.log_metrics(
                    {'acc': train_acc, 'loss': train_loss_sum},
                    step=(epoch + 1) * len(train_dataloader),
                    epoch=epoch + 1)
        print("train acc epoch {}/{} loss:{} train_acc:{}".format(
            epoch, args.epochs, train_loss_sum, train_acc),
              flush=True)

        #######################################################################
        # Train acc
        # linear_model.eval()
        # train_acc = 0
        # train_loss_sum = 0.0
        # for iter, input in enumerate(train_acc_dataloader):
        #     imgs = input[0].to(device)
        #     if args.model != "resnet":
        #         imgs = imgs.unsqueeze(1)
        #     lbls = input[1].flatten().to(device)
        #
        #     # output = model(imgs)
        #     # output = linear_model(output)
        #     output = linear_model(model(imgs))
        #     loss = F.cross_entropy(output, lbls)
        #     train_loss_sum += float(loss.data)
        #     train_acc += int(sum(torch.argmax(output, dim=1) == lbls))
        #
        #     print("train acc epoch {}/{}\titer {}/{} loss:{} {:.3f}s/iter".format(
        #         epoch,
        #         args.epochs,
        #         iter+1,
        #         len(train_acc_dataloader),
        #         loss.data,
        #         time.time() - t),
        #         flush=True)
        #     t = time.time()
        #
        #
        # train_acc /= len(train_acc_dataset)
        # train_loss_sum /= len(train_acc_dataloader)
        # print("train acc epoch {}/{} loss:{} train_acc:{}".format(
        #     epoch, args.epochs, train_loss_sum, train_acc), flush=True)

        #######################################################################
        # Train dev acc
        # # linear_model.eval()
        # train_dev_acc = 0
        # train_dev_loss_sum = 0.0
        # for iter, input in enumerate(train_dev_acc_dataloader):
        #     imgs = input[0].to(device)
        #     if args.model != "resnet":
        #         imgs = imgs.unsqueeze(1)
        #     lbls = input[1].flatten().to(device)
        #
        #     output = model(imgs)
        #     output = linear_model(output)
        #     # output = linear_model(model(imgs))
        #     loss = F.cross_entropy(output, lbls)
        #     train_dev_loss_sum += float(loss.data)
        #     train_dev_acc += int(sum(torch.argmax(output, dim=1) == lbls))
        #
        #     print("train dev acc epoch {}/{}\titer {}/{} loss:{} {:.3f}s/iter".format(
        #         epoch,
        #         args.epochs,
        #         iter+1,
        #         len(train_dev_acc_dataloader),
        #         loss.data,
        #         time.time() - t),
        #         flush=True)
        #     t = time.time()
        #
        # train_dev_acc /= len(train_dev_acc_dataset)
        # train_dev_loss_sum /= len(train_dev_acc_dataloader)
        # print("train dev epoch {}/{} loss:{} train_dev_acc:{}".format(
        #     epoch, args.epochs, train_dev_loss_sum, train_dev_acc), flush=True)

        #######################################################################
        # Val dev
        # # linear_model.eval()
        # val_dev_acc = 0
        # val_dev_loss_sum = 0.0
        # for iter, input in enumerate(val_dev_dataloader):
        #     imgs = input[0].to(device)
        #     if args.model != "resnet":
        #         imgs = imgs.unsqueeze(1)
        #     lbls = input[1].flatten().to(device)
        #
        #     output = model(imgs)
        #     output = linear_model(output)
        #     loss = F.cross_entropy(output, lbls)
        #     val_dev_loss_sum += float(loss.data)
        #     val_dev_acc += int(sum(torch.argmax(output, dim=1) == lbls))
        #
        #     print("val dev epoch {}/{} iter {}/{} loss:{} {:.3f}s/iter".format(
        #         epoch,
        #         args.epochs,
        #         iter+1,
        #         len(val_dev_dataloader),
        #         loss.data,
        #         time.time() - t),
        #         flush=True)
        #     t = time.time()
        #
        # val_dev_acc /= len(val_dev_dataset)
        # val_dev_loss_sum /= len(val_dev_dataloader)
        # print("val dev epoch {}/{} loss:{} val_dev_acc:{}".format(
        #     epoch, args.epochs, val_dev_loss_sum, val_dev_acc), flush=True)

        #######################################################################
        # Val
        linear_model.eval()
        val_acc = 0
        val_loss_sum = 0.0
        for iter, input in enumerate(val_dataloader):
            if (time.time() - start_time > args.time * 3600 - 300
                    and comet_exp is not None):
                comet_exp.end()
                sys.exit(-1)

            imgs = input[0].to(device)
            if args.model != "resnet":
                imgs = imgs.unsqueeze(1)
            lbls = input[1].flatten().to(device)

            output = model(imgs)
            output = linear_model(output)
            loss = F.cross_entropy(output, lbls)
            val_loss_sum += float(loss.data)
            val_acc += int(sum(torch.argmax(output, dim=1) == lbls))

            # log_text = "val epoch {}/{} iter {}/{} loss:{} {:.3f}s/iter"
            if iter % 1500 == 0:
                log_text = "val epoch {}/{} iter {}/{} loss:{}"
                print(log_text.format(epoch, args.epochs, iter + 1,
                                      len(val_dataloader), loss.data,
                                      time.time() - t),
                      flush=False)
                t = time.time()

        val_acc /= len(_val_dataset)
        val_loss_sum /= len(val_dataloader)
        print("val epoch {}/{} loss:{} val_acc:{}".format(
            epoch, args.epochs, val_loss_sum, val_acc))
        if comet_exp is not None:
            with comet_exp.test():
                comet_exp.log_metrics(
                    {'acc': val_acc, 'loss': val_loss_sum},
                    step=(epoch + 1) * len(train_dataloader),
                    epoch=epoch + 1)
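        # Save checkpoints only on validation-accuracy improvements (see below).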

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            linear_save_path = os.path.join(log_dir,
                                            "{}.linear.pth".format(epoch))
            model_save_path = os.path.join(log_dir,
                                           "{}.model.pth".format(epoch))
            torch.save(linear_model.state_dict(), linear_save_path)
            torch.save(model.state_dict(), model_save_path)

        # Check bias and variance
        print(
            "Epoch {} lr {} total: train_loss:{} train_acc:{} val_loss:{} val_acc:{}"
            .format(epoch, current_lr, train_loss_sum, train_acc, val_loss_sum,
                    val_acc),
            flush=True)
        # print("Epoch {} lr {} total: train_acc:{} train_dev_acc:{} val_dev_acc:{} val_acc:{}".format(
        #     epoch, current_lr, train_acc, train_dev_acc, val_dev_acc, val_acc), flush=True)

    print("The best epoch: {} acc: {}".format(best_epoch, best_acc))