Example #1
from uuid import uuid4

from comet_ml import Experiment as CometExperiment
from flatten_dict import flatten  # assumed source of flatten(..., reducer='underscore')


class Experiment:
    def __init__(self, api_key=None, **kwargs):
        self._exp = None
        self._id = uuid4().hex
        if api_key:
            self._exp = CometExperiment(api_key,
                                        log_code=False,
                                        auto_param_logging=False,
                                        auto_metric_logging=False,
                                        **kwargs)
            self._id = self._exp.get_key()

    def log_metric(self, name, value, step=None, epoch=None):
        if self._exp:
            self._exp.log_metric(name, value, step=step, epoch=epoch)

    def log_epoch_end(self, epoch_cnt, step=None):
        if self._exp:
            self._exp.log_epoch_end(epoch_cnt, step=step)

    def log_parameters(self, hp):
        if self._exp:
            self._exp.log_parameters(flatten(hp, reducer='underscore'))

    @property
    def id(self):
        return self._id[:12]
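
A quick usage sketch of the wrapper above (the parameter values are made up for illustration): constructed without an api_key, it falls back to a local uuid4 id and every log_* method becomes a no-op, so training code can log unconditionally whether or not Comet is configured.

# Offline run: no api_key, so every log_* call is silently skipped.
exp = Experiment()
exp.log_parameters({'optimizer': {'name': 'adam', 'lr': 1e-3}})
exp.log_metric('loss', 0.42, step=1)
print(exp.id)  # first 12 hex characters of a locally generated uuid4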
Example #2
def train():
    # Read at most LIMIT lines from each file; the context managers ensure
    # the file handles are closed.
    with open('data/x.txt') as f:
        x_lines = list(toolz.take(LIMIT, f.read().lower().split('\n')))
    with open('data/y.txt') as f:
        y_lines = list(toolz.take(LIMIT, f.read().lower().split('\n')))

    encoder = encoder_for_lines(S2S_PARAMS, x_lines + y_lines)

    # Different encoder implementations expose their vocabulary under
    # different attribute names; resolve it once and reuse it below.
    try:
        vocab = encoder.word_vocab
    except AttributeError:
        vocab = encoder.vocabulary_
    start_idx = int(vocab[S2S_PARAMS.start_token])
    pad_idx = vocab[PAD_TOKEN]

    reverse_enc = {idx: word for word, idx in vocab.items()}
    model = build_model(S2S_PARAMS, start_idx, pad_idx)

    x = encode_data(encoder, x_lines, is_input=True)
    y = encode_data(encoder, y_lines, is_input=False)

    print(x.shape, y.shape)

    # Trim both arrays to a whole number of batches
    x = x[:S2S_PARAMS.batch_size * (len(x) // S2S_PARAMS.batch_size)]
    y = y[:S2S_PARAMS.batch_size * (len(y) // S2S_PARAMS.batch_size)]

    test_x = x[:S2S_PARAMS.batch_size]
    losses = []

    if USE_COMET:
        experiment = Experiment(api_key="DQqhNiimkjP0gK6c8iGz9orzL",
                                log_code=True)
        experiment.log_multiple_params(S2S_PARAMS._asdict())
        for idx in range(1000):
            print("Shuffling data...")
            random_idx = random.sample(range(len(x)), len(x))
            x = x[random_idx]
            y = y[random_idx]
            print("Training in epoch " + str(idx))
            losses.append(model.train_epoch(x, y, experiment=experiment))
            experiment.log_epoch_end(idx)
            print('Loss history: {}'.format(', '.join(
                ['{:.4f}'.format(loss) for loss in losses])))
            test_y = model.predict(test_x)
            for i in range(min(3, S2S_PARAMS.batch_size)):
                print('> ' + ' '.join(
                    reverse_enc.get(token_idx, '<unk/>')
                    for token_idx in list(test_y[i])))
    else:
        for idx in range(1000):
            print("Training in epoch " + str(idx))
            model.train_epoch(x, y)
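
The shuffle above reuses one index list for both arrays so that input/output pairs stay aligned. An equivalent and common numpy idiom, assuming x and y are equal-length numpy arrays as encode_data appears to return:

import numpy as np

# One permutation applied to both arrays keeps (x, y) pairs aligned
perm = np.random.permutation(len(x))
x, y = x[perm], y[perm]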
Example #3
class CometMLMonitor(MonitorBase):
    """
    Send data to https://www.comet.ml.

    Note:
        1. comet_ml requires you to `import comet_ml` before importing tensorflow or tensorpack.
        2. The "automatic output logging" feature of comet_ml will make the training progress bar appear to freeze.
           Therefore the feature is disabled by default.
    """
    def __init__(self, experiment=None, api_key=None, tags=None, **kwargs):
        """
        Args:
            experiment (comet_ml.Experiment): if provided, all other arguments are ignored
            api_key (str): your comet.ml API key
            tags (list[str]): experiment tags
            kwargs: other arguments passed to :class:`comet_ml.Experiment`.
        """
        if experiment is not None:
            self._exp = experiment
            assert api_key is None and tags is None and len(kwargs) == 0
        else:
            from comet_ml import Experiment
            kwargs.setdefault(
                'log_code', True
            )  # code logging itself is not functional here, but git patch logging requires it
            kwargs.setdefault('auto_output_logging', None)
            self._exp = Experiment(api_key=api_key, **kwargs)
            if tags is not None:
                self._exp.add_tags(tags)

        self._exp.set_code(
            "Code logging is impossible because there are too many files ...")
        self._exp.log_dependency('tensorpack', __git_version__)

    @property
    def experiment(self):
        """
        The :class:`comet_ml.Experiment` instance.
        """
        return self._exp

    def _before_train(self):
        self._exp.set_model_graph(tf.get_default_graph())

    @HIDE_DOC
    def process_scalar(self, name, val):
        self._exp.log_metric(name, val, step=self.global_step)

    def _after_train(self):
        self._exp.end()

    def _after_epoch(self):
        self._exp.log_epoch_end(self.epoch_num)
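
A minimal sketch of attaching this monitor to a tensorpack training run; MyModel (a ModelDesc) and df (a DataFlow) are assumed to exist, and per the docstring above comet_ml must be imported before tensorflow or tensorpack:

import comet_ml  # must come before tensorflow / tensorpack
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

config = TrainConfig(
    model=MyModel(),   # assumed ModelDesc
    dataflow=df,       # assumed DataFlow
    monitors=[CometMLMonitor(api_key='YOUR_API_KEY', tags=['baseline'])],
)
launch_train_with_config(config, SimpleTrainer())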
Example #4
class CometMLMonitor(MonitorBase):
    """
    Send scalar data and the graph to https://www.comet.ml.

    Note:
        1. comet_ml requires you to `import comet_ml` before importing tensorflow or tensorpack.
        2. The "automatic output logging" feature of comet_ml will make the training progress bar appear to freeze.
           Therefore the feature is disabled by default.
    """
    def __init__(self, experiment=None, tags=None, **kwargs):
        """
        Args:
            experiment (comet_ml.Experiment): if provided, all other arguments are ignored
            tags (list[str]): experiment tags
            kwargs: arguments used to initialize :class:`comet_ml.Experiment`,
                such as project name, API key, etc.
                Refer to its documentation for details.
        """
        if experiment is not None:
            self._exp = experiment
            assert tags is None and len(kwargs) == 0
        else:
            from comet_ml import Experiment
            kwargs.setdefault(
                'log_code', True
            )  # code logging itself is not functional here, but git patch logging requires it
            kwargs.setdefault('auto_output_logging', None)
            self._exp = Experiment(**kwargs)
            if tags is not None:
                self._exp.add_tags(tags)

        self._exp.set_code("Code logging is impossible ...")
        self._exp.log_dependency('tensorpack', __git_version__)

    @property
    def experiment(self):
        """
        The :class:`comet_ml.Experiment` instance.
        """
        return self._exp

    def _before_train(self):
        self._exp.set_model_graph(tf.get_default_graph())

    @HIDE_DOC
    def process_scalar(self, name, val):
        self._exp.log_metric(name, val, step=self.global_step)

    @HIDE_DOC
    def process_image(self, name, val):
        self._exp.set_step(self.global_step)
        for idx, v in enumerate(val):
            log_name = "{}_step{}{}".format(
                name, self.global_step, "_" + str(idx) if len(val) > 1 else "")

            self._exp.log_image(v,
                                image_format="jpeg",
                                name=log_name,
                                image_minmax=(0, 255))

    def _after_train(self):
        self._exp.end()

    def _after_epoch(self):
        self._exp.log_epoch_end(self.epoch_num)
Example #5
    def train(self, train_data, cometml_key=None):
        if cometml_key is not None:
            experiment = Experiment(api_key=cometml_key,
                                    project_name="dsgym-tgan",
                                    workspace="baukebrenninkmeijer")
            experiment.log_parameter('batch_size', self.batch_size)
            experiment.log_parameter('embeddingDim', self.embeddingDim)
            experiment.log_parameter('genDim', self.genDim)
            experiment.log_parameter('disDim', self.disDim)
            experiment.log_parameter('GAN version', 'TGAN')

        # writer = SummaryWriter()
        # train_data = monkey_with_train_data(train_data)
        print('Transforming data...')
        self.transformer = BGMTransformer(self.meta)
        self.transformer.fit(train_data)
        pickle.dump(self.transformer,
                    open(f'{self.working_dir}/transformer.pkl', 'wb'))
        train_data = self.transformer.transform(train_data)
        # ncp1 = sum(self.transformer.components[0])
        # ncp2 = sum(self.transformer.components[1])
        # for i in range(ncp1):
        #     for j in range(ncp2):
        #         cond1 = train_data[:, 1 + i] > 0
        #         cond2 = train_data[:, 2 + ncp1 + j]
        #         cond = np.logical_and(cond1, cond2)
        #
        #         mean1 = train_data[cond, 0].mean()
        #         mean2 = train_data[cond, 1 + ncp1].mean()
        #
        #         std1 = train_data[cond, 0].std()
        #         std2 = train_data[cond, 1 + ncp1].std()
        #         print(i, j, np.sum(cond), mean1, std1, mean2, std2, sep='\t')

        # dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self.device))
        # loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        data_sampler = Sampler(train_data, self.transformer.output_info)

        data_dim = self.transformer.output_dim
        self.cond_generator = Cond(train_data, self.transformer.output_info)

        self.generator = Generator(
            self.embeddingDim + self.cond_generator.n_opt, self.genDim,
            data_dim).to(self.device)
        self.discriminator = Discriminator(
            data_dim + self.cond_generator.n_opt, self.disDim).to(self.device)

        optimizerG = optim.Adam(self.generator.parameters(),
                                lr=2e-4,
                                betas=(0.5, 0.9),
                                weight_decay=self.l2scale)
        optimizerD = optim.Adam(self.discriminator.parameters(),
                                lr=2e-4,
                                betas=(0.5,
                                       0.9))  #, weight_decay=self.l2scale)
        # pickle.dump(self, open(f'{self.working_dir}/tgan_synthesizer.pkl', 'wb'))
        # writer.add_graph(self.generator)

        max_epoch = max(self.store_epoch)
        assert self.batch_size % 2 == 0
        mean = torch.zeros(self.batch_size,
                           self.embeddingDim,
                           device=self.device)
        std = mean + 1

        print('Starting training loop...')
        steps_per_epoch = len(train_data) // self.batch_size
        for i in tqdm(range(max_epoch)):
            for id_ in tqdm(range(steps_per_epoch), leave=False):
                fakez = torch.normal(mean=mean, std=std)

                condvec = self.cond_generator.generate(self.batch_size)
                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                    real = data_sampler.sample(self.batch_size, col, opt)
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self.device)
                    m1 = torch.from_numpy(m1).to(self.device)
                    fakez = torch.cat([fakez, c1], dim=1)

                    perm = np.arange(self.batch_size)
                    np.random.shuffle(perm)
                    real = data_sampler.sample(self.batch_size, col[perm],
                                               opt[perm])
                    c2 = c1[perm]

                fake = self.generator(fakez)
                fakeact = apply_activate(fake, self.transformer.output_info)

                real = torch.from_numpy(real.astype('float32')).to(self.device)

                if c1 is not None:
                    fake_cat = torch.cat([fakeact, c1], dim=1)
                    real_cat = torch.cat([real, c2], dim=1)
                else:
                    real_cat = real
                    # Use the activated fake sample so it matches the
                    # transformed real data, as in the branch above
                    fake_cat = fakeact

                # print(real_cat[0])
                # print(fake_cat[0])
                # assert 0

                y_fake = self.discriminator(fake_cat)
                y_real = self.discriminator(real_cat)

                # loss_d = -(torch.log(torch.sigmoid(y_real) + 1e-4).mean()) - (torch.log(1. - torch.sigmoid(y_fake) + 1e-4).mean())
                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
                pen = calc_gradient_penalty(self.discriminator, real_cat,
                                            fake_cat, self.device)

                optimizerD.zero_grad()
                pen.backward(retain_graph=True)
                loss_d.backward()
                optimizerD.step()

                # for p in discriminator.parameters():
                # p.data.clamp_(-0.05, 0.05)

                fakez = torch.normal(mean=mean, std=std)

                condvec = self.cond_generator.generate(self.batch_size)
                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self.device)
                    m1 = torch.from_numpy(m1).to(self.device)
                    fakez = torch.cat([fakez, c1], dim=1)
                fake = self.generator(fakez)

                fakeact = apply_activate(fake, self.transformer.output_info)
                if c1 is not None:
                    y_fake = self.discriminator(torch.cat([fakeact, c1],
                                                          dim=1))
                else:
                    y_fake = self.discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    cross_entropy = cond_loss(fake,
                                              self.transformer.output_info, c1,
                                              m1)
                # loss_g = -torch.log(torch.sigmoid(y_fake) + 1e-4).mean() + cross_entropy
                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()
                if cometml_key:
                    experiment.log_metric('Discriminator Loss', loss_d)
                    experiment.log_metric('Generator Loss', loss_g)

            # print("---")
            # print(fakeact[:, 0].mean(), fakeact[:, 0].std())
            # print(fakeact[:, 1 + ncp1].mean(), fakeact[:, 1 + ncp1].std())
            print(i + 1, loss_d.data, pen.data, loss_g.data, cross_entropy)
            if cometml_key:
                experiment.log_epoch_end(i)
            if i + 1 in self.store_epoch:
                print('Saving model')
                torch.save(
                    {
                        "generator": self.generator.state_dict(),
                        "discriminator": self.discriminator.state_dict(),
                    }, "{}/model_{}.tar".format(self.working_dir, i + 1))
        if cometml_key is not None:
            experiment.end()
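
The discriminator loss above is the Wasserstein critic loss, with calc_gradient_penalty defined elsewhere in the repo. For reference, a standard WGAN-GP penalty looks like the sketch below; the signature and the default weight of 10 are assumptions, so the repo's own helper may differ:

import torch

def calc_gradient_penalty(netD, real_data, fake_data, device, lambda_=10):
    # Sample one interpolation coefficient per row and blend real/fake
    alpha = torch.rand(real_data.size(0), 1, device=device)
    alpha = alpha.expand(real_data.size())
    interpolates = (alpha * real_data +
                    (1 - alpha) * fake_data).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True)[0]
    # Penalize deviation of the per-sample gradient norm from 1
    return lambda_ * ((gradients.norm(2, dim=1) - 1) ** 2).mean()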
Example #6
def main(_):
    experiment = Experiment(api_key="xXtJguCo8yFdU7dpjEpo6YbHw",
                            project_name=args.experiment_name)
    hyper_params = {
        "learning_rate": args.lr,
        "num_epochs": args.max_epoch,
        "batch_size": args.single_batch_size,
        "alpha": args.alpha,
        "beta": args.beta,
        "gamma": args.gamma,
        "loss": args.loss
    }
    experiment.log_multiple_params(hyper_params)

    # TODO: split file support
    with tf.Graph().as_default():
        global save_model_dir
        start_epoch = 0
        global_counter = 0

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
            log_device_placement=False,
        )
        with tf.Session(config=config) as sess:
            # sess=tf_debug.LocalCLIDebugWrapperSession(sess,ui_type='readline')
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          learning_rate=args.lr,
                          max_gradient_norm=5.0,
                          alpha=args.alpha,
                          beta=args.beta,
                          gamma=args.gamma,
                          loss_type=args.loss,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))
            # param init/restore
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))
                start_epoch = model.epoch.eval() + 1
                global_counter = model.global_step.eval() + 1
            else:
                print("Created model with fresh parameters.")
                tf.global_variables_initializer().run()

            # train and validate
            is_summary, is_summary_image, is_validate = False, False, False

            summary_interval = 5
            summary_val_interval = 10
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            experiment.set_model_graph(sess.graph)

            # training
            with experiment.train():
                for epoch in range(start_epoch, args.max_epoch):
                    counter = 0
                    batch_time = time.time()
                    experiment.log_current_epoch(epoch)

                    for batch in iterate_data(
                            train_dir,
                            shuffle=True,
                            aug=True,
                            is_testset=False,
                            batch_size=args.single_batch_size *
                            cfg.GPU_USE_COUNT,
                            multi_gpu_sum=cfg.GPU_USE_COUNT):

                        counter += 1
                        global_counter += 1
                        experiment.set_step(global_counter)
                        if counter % summary_interval == 0:
                            is_summary = True
                        else:
                            is_summary = False
                        epochs = args.max_epoch
                        start_time = time.time()
                        ret = model.train_step(sess,
                                               batch,
                                               train=True,
                                               summary=is_summary)
                        forward_time = time.time() - start_time
                        batch_time = time.time() - batch_time
                        param = ret
                        params = {
                            "loss": param[0],
                            "cls_loss": param[1],
                            "cls_pos_loss": param[2],
                            "cls_neg_loss": param[3]
                        }
                        experiment.log_multiple_metrics(params)
                        # print(ret)
                        print(
                            'train: {} @ epoch:{}/{} loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} forward time: {:.4f} batch time: {:.4f}'
                            .format(counter, epoch, epochs, ret[0], ret[1],
                                    ret[2], ret[3], forward_time, batch_time))
                        # with open('log/train.txt', 'a') as f:
                        # f.write( 'train: {} @ epoch:{}/{} loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} forward time: {:.4f} batch time: {:.4f}'.format(counter,epoch, epochs, ret[0], ret[1], ret[2], ret[3], forward_time, batch_time))

                        #print(counter, summary_interval, counter % summary_interval)
                        if counter % summary_interval == 0:
                            print("summary_interval now")
                            summary_writer.add_summary(ret[-1], global_counter)

                        #print(counter, summary_val_interval, counter % summary_val_interval)
                        if counter % summary_val_interval == 0:
                            print("summary_val_interval now")
                            batch = sample_test_data(
                                val_dir,
                                args.single_batch_size * cfg.GPU_USE_COUNT,
                                multi_gpu_sum=cfg.GPU_USE_COUNT)

                            ret = model.validate_step(sess,
                                                      batch,
                                                      summary=True)
                            summary_writer.add_summary(ret[-1], global_counter)

                            try:
                                ret = model.predict_step(sess,
                                                         batch,
                                                         summary=True)
                                summary_writer.add_summary(
                                    ret[-1], global_counter)
                            except Exception:
                                print("prediction skipped due to error")

                        if check_if_should_pause(args.tag):
                            model.saver.save(sess,
                                             os.path.join(
                                                 save_model_dir, 'checkpoint'),
                                             global_step=model.global_step)
                            print('pause and save model @ {} steps:{}'.format(
                                save_model_dir, model.global_step.eval()))
                            sys.exit(0)

                        batch_time = time.time()
                    experiment.log_epoch_end(epoch)
                    sess.run(model.epoch_add_op)

                    model.saver.save(sess,
                                     os.path.join(save_model_dir,
                                                  'checkpoint'),
                                     global_step=model.global_step)

                    # dump test data every 10 epochs
                    if (epoch + 1) % 10 == 0:
                        # create output folder
                        os.makedirs(os.path.join(args.output_path, str(epoch)),
                                    exist_ok=True)
                        os.makedirs(os.path.join(args.output_path, str(epoch),
                                                 'data'),
                                    exist_ok=True)
                        if args.vis:
                            os.makedirs(os.path.join(args.output_path,
                                                     str(epoch), 'vis'),
                                        exist_ok=True)

                        for batch in iterate_data(
                                val_dir,
                                shuffle=False,
                                aug=False,
                                is_testset=False,
                                batch_size=args.single_batch_size *
                                cfg.GPU_USE_COUNT,
                                multi_gpu_sum=cfg.GPU_USE_COUNT):

                            if args.vis:
                                tags, results, front_images, bird_views, heatmaps = model.predict_step(
                                    sess, batch, summary=False, vis=True)
                            else:
                                tags, results = model.predict_step(
                                    sess, batch, summary=False, vis=False)

                            for tag, result in zip(tags, results):
                                of_path = os.path.join(args.output_path,
                                                       str(epoch), 'data',
                                                       tag + '.txt')
                                with open(of_path, 'w+') as f:
                                    labels = box3d_to_label(
                                        [result[:, 1:8]], [result[:, 0]],
                                        [result[:, -1]],
                                        coordinate='lidar')[0]
                                    for line in labels:
                                        f.write(line)
                                    print('write out {} objects to {}'.format(
                                        len(labels), tag))
                            # dump visualizations
                            if args.vis:
                                for tag, front_image, bird_view, heatmap in zip(
                                        tags, front_images, bird_views,
                                        heatmaps):
                                    front_img_path = os.path.join(
                                        args.output_path, str(epoch), 'vis',
                                        tag + '_front.jpg')
                                    bird_view_path = os.path.join(
                                        args.output_path, str(epoch), 'vis',
                                        tag + '_bv.jpg')
                                    heatmap_path = os.path.join(
                                        args.output_path, str(epoch), 'vis',
                                        tag + '_heatmap.jpg')
                                    cv2.imwrite(front_img_path, front_image)
                                    cv2.imwrite(bird_view_path, bird_view)
                                    cv2.imwrite(heatmap_path, heatmap)

                        # execute evaluation code
                        cmd_1 = "./kitti_eval/launch_test.sh"
                        cmd_2 = os.path.join(args.output_path, str(epoch))
                        cmd_3 = os.path.join(args.output_path, str(epoch),
                                             'log')
                        os.system(" ".join([cmd_1, cmd_2, cmd_3]))

            print('train done. total epoch:{} iter:{}'.format(
                epoch, model.global_step.eval()))

            # finally, save the model
            model.saver.save(sess,
                             os.path.join(save_model_dir, 'checkpoint'),
                             global_step=model.global_step)
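
Note that log_multiple_params and log_multiple_metrics used here are older Comet API names; current SDK versions expose the same functionality as log_parameters and log_metrics:

experiment.log_parameters(hyper_params)                   # replaces log_multiple_params
experiment.log_metrics({'loss': 0.1, 'cls_loss': 0.05})   # replaces log_multiple_metrics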
Example #7
                evaluation_metrics = [
                    ("val_precision", precision.mean()),
                    ("val_recall", recall.mean()),
                    ("val_mAP", AP.mean()),
                    ("val_f1", f1.mean()),
                ]
                logger.list_of_scalars_summary(evaluation_metrics, epoch)
                with experiment.test():
                    experiment.log_metric("precision",
                                          precision.mean(),
                                          step=epoch)

                # Print class APs and mAP
                ap_table = [["Index", "Class name", "AP"]]
                for i, c in enumerate(ap_class):
                    ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
                print(AsciiTable(ap_table).table)
                print(f"---- mAP {AP.mean()}")
                with experiment.test():
                    experiment.log_metric("AP_baby", AP[0], step=epoch)
            else:
                print('CANNOT EVALUATE!!! ----------------------------')

        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(),
                       f"checkpoints/yolov3_ckpt_{epoch}.pth")

        experiment.log_epoch_end(epoch)

    experiment.end()
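
For context, Comet's experiment.train() and experiment.test() context managers only set the logging context: metrics logged inside them are shown with a matching prefix in the UI. A minimal sketch, assuming an existing Experiment instance:

with experiment.train():
    experiment.log_metric("loss", 0.5, step=1)       # appears as train_loss
with experiment.test():
    experiment.log_metric("precision", 0.9, step=1)  # appears as test_precision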
Example #8
            or
            (mean_val_metrics['cd_recalls'] > best_metrics['cd_recalls'])
            or
            (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores'])):
        # Insert training and epoch information into the metadata dictionary
        metadata['validation_metrics'] = mean_val_metrics

        # Save to comet.ml and in GCS
        with open('/tmp/metadata_epoch_' + str(epoch) + '.json', 'w') as fout:
            json.dump(metadata, fout)

        torch.save(model, '/tmp/checkpoint_epoch_'+str(epoch)+'.pt')
        upload_file_path = '/tmp/checkpoint_epoch_'+str(epoch)+'.pt'
        upload_metadata_file_path = '/tmp/metadata_epoch_' + str(epoch) + '.json'
        experiment.outputs_store.upload_file(upload_file_path)
        experiment.outputs_store.upload_file(upload_metadata_file_path)
        comet.log_asset(upload_metadata_file_path)
        best_metrics = mean_val_metrics

    # Log all train and validation metrics
    log_train_metrics = {"train_"+k: v for k, v in mean_train_metrics.items()}
    log_val_metrics = {"validate_"+k: v for k, v in mean_val_metrics.items()}
    epoch_metrics = {'epoch': epoch, **log_train_metrics, **log_val_metrics}

    experiment.log_metrics(**epoch_metrics)

    # Set experiment to running properly (for filtering out bad runs)
    comet.log_other('status', 'running')
    comet.log_epoch_end(epoch)
comet.log_other('status', 'complete')
class eICU_Operator(TrainingOperator):
    def setup(self, config):
        # Number of RaySGD workers
        self.num_workers = config.get('num_workers', 1)
        # Fetch the Comet ML credentials
        self.comet_ml_api_key = config['comet_ml_api_key']
        self.comet_ml_project_name = config['comet_ml_project_name']
        self.comet_ml_workspace = config['comet_ml_workspace']
        self.log_comet_ml = config.get('log_comet_ml', True)
        self.comet_ml_save_model = config.get('comet_ml_save_model', True)
        # Fetch model and dataset parameters
        self.model_class = config.get('model', 'VanillaRNN')  # Model class
        self.dataset_mode = config.get(
            'dataset_mode', 'one hot encoded'
        )  # The mode in which we'll use the data, either one hot encoded or pre-embedded
        self.ml_core = config.get(
            'ml_core', 'deep learning'
        )  # The core machine learning type we'll use; either traditional ML or DL
        self.use_delta_ts = config.get(
            'use_delta_ts',
            False)  # Indicates if we'll use time variation info
        self.time_window_h = config.get(
            'time_window_h',
            48)  # Number of hours on which we want to predict mortality
        # Additional properties and relevant training information
        self.step = 0  # Number of iteration steps done so far
        self.print_every = config.get(
            'print_every', 10)  # Steps interval where the metrics are printed
        self.val_loss_min = np.inf  # Start with an infinitely big minimum validation loss
        self.clip_value = config.get(
            'clip_value',
            0.5)  # Gradient clipping value, to avoid exploding gradients
        self.features_list = config.get(
            'features_list',
            None)  # Names of the features being used in the current pipeline
        self.model_type = config.get(
            'model_type', 'multivariate_rnn')  # Type of model to train
        self.padding_value = config.get(
            'padding_value',
            999999)  # Value to use in the padding, to fill the sequences
        self.cols_to_remove = config.get(
            'cols_to_remove', [0, 1]
        )  # List of indices of columns to remove from the features before feeding to the model
        self.is_custom = config.get(
            'is_custom',
            False)  # Specifies if the model being used is a custom built one
        self.already_embedded = config.get(
            'already_embedded', False
        )  # Indicates if the categorical features are already embedded when fetching a batch
        self.batch_size = config.get(
            'batch_size', 32
        )  # The number of samples used in each training, validation or test iteration
        self.n_epochs = config.get(
            'n_epochs', 1
        )  # Number of epochs, i.e. the number of times to iterate through all of the training data
        self.lr = config.get('lr', 0.001)  # Learning rate
        self.models_path = config.get(
            'models_path',
            '')  # Path to the directory where the models are stored
        self.see_progress = config.get(
            'see_progress', True
        )  # Sets if a progress bar is shown for each training and validation loop
        # Register all the hyperparameters
        if self.num_workers == 1:
            model = self.model
        else:
            # Get the original model, as the current one is wrapped in DistributedDataParallel
            model = self.model.module
        model_args = inspect.getfullargspec(model.__init__).args[1:]
        self.hyper_params = dict([(param, getattr(model, param))
                                  for param in model_args])
        self.hyper_params.update({
            'batch_size': self.batch_size,
            'n_epochs': self.n_epochs,
            'learning_rate': self.lr
        })
        if self.log_comet_ml is True:
            # Create a new Comet.ml experiment
            self.experiment = Experiment(
                api_key=self.comet_ml_api_key,
                project_name=self.comet_ml_project_name,
                workspace=self.comet_ml_workspace,
                auto_param_logging=False,
                auto_metric_logging=False,
                auto_output_logging=False)
            self.experiment.log_other('completed', False)
            self.experiment.log_other('random_seed', du.random_seed)
            # Report hyperparameters to Comet.ml
            self.experiment.log_parameters(self.hyper_params)
            self.experiment.log_parameters(config)
            if self.features_list is not None:
                # Log the names of the features being used
                self.experiment.log_other('features_list', self.features_list)
        if self.clip_value is not None:
            # Set gradient clipping to avoid exploding gradients
            for p in self.model.parameters():
                p.register_hook(lambda grad: torch.clamp(
                    grad, -self.clip_value, self.clip_value))

    def set_model_filename(self, val_loss):
        # Start with the model class name
        if self.model_class == 'VanillaRNN':
            model_filename = 'rnn'
        elif self.model_class == 'VanillaLSTM':
            model_filename = 'lstm'
        elif self.model_class == 'TLSTM':
            model_filename = 'tlstm'
        elif self.model_class == 'MF1LSTM':
            model_filename = 'mf1lstm'
        elif self.model_class == 'MF2LSTM':
            model_filename = 'mf2lstm'
        else:
            raise Exception(
                f'ERROR: {self.model_class} is an invalid model type. Please specify either "VanillaRNN", "VanillaLSTM", "TLSTM", "MF1LSTM" or "MF2LSTM".'
            )
        # Add dataset mode information
        if self.dataset_mode == 'pre-embedded':
            model_filename = model_filename + '_pre_embedded'
        elif self.dataset_mode == 'learn embedding':
            model_filename = model_filename + '_with_embedding'
        elif self.dataset_mode == 'one hot encoded':
            model_filename = model_filename + '_one_hot_encoded'
        # Use of time variation information
        if self.use_delta_ts is not False and (self.model_class == 'VanillaRNN'
                                               or self.model_class
                                               == 'VanillaLSTM'):
            model_filename = model_filename + '_delta_ts'
        # Add the validation loss and timestamp
        current_datetime = datetime.now().strftime('%d_%m_%Y_%H_%M')
        model_filename = f'{val_loss:.4f}_valloss_{model_filename}_{current_datetime}.pth'
        return model_filename

    @override(TrainingOperator)
    def validate(self, val_iterator, info):
        # Number of iteration steps done so far
        step = info.get('step', 0)
        # Initialize the validation metrics
        val_loss = 0
        val_acc = 0
        val_auc = list()
        if self.num_workers == 1:
            model = self.model
        else:
            # Get the original model, as the current one is wrapped in DistributedDataParallel
            model = self.model.module
        if model.n_outputs > 1:
            val_auc_wgt = list()
        # Loop through the validation data
        for features, labels in du.utils.iterations_loop(
                val_iterator, see_progress=self.see_progress,
                desc='Val batches'):
            # Turn off gradients for validation, saves memory and computations
            with torch.no_grad():
                if self.is_custom is False:
                    # Find the original sequence lengths
                    seq_lengths = du.search_explore.find_seq_len(
                        labels, padding_value=self.padding_value)
                else:
                    # No need to find the sequence lengths now
                    seq_lengths = None
                if self.use_gpu is True:
                    # Move data to GPU
                    features, labels = features.to(self.device), labels.to(
                        self.device)
                # Do inference on the data
                if self.model_type.lower() == 'multivariate_rnn':
                    (pred, correct_pred, scores, labels,
                     loss) = (du.deep_learning.inference_iter_multi_var_rnn(
                         self.model,
                         features,
                         labels,
                         padding_value=self.padding_value,
                         cols_to_remove=self.cols_to_remove,
                         is_train=False,
                         prob_output=True,
                         is_custom=self.is_custom,
                         already_embedded=self.already_embedded,
                         seq_lengths=seq_lengths,
                         distributed_train=(self.num_workers > 1)))
                elif self.model_type.lower() == 'mlp':
                    pred, correct_pred, scores, loss = (
                        du.deep_learning.inference_iter_mlp(
                            self.model,
                            features,
                            labels,
                            self.cols_to_remove,
                            is_train=False,
                            prob_output=True))
                else:
                    raise Exception(
                        f'ERROR: Invalid model type. It must be "multivariate_rnn" or "mlp", not {self.model_type}.'
                    )
                val_loss += loss  # Add the validation loss of the current batch
                val_acc += torch.mean(
                    correct_pred.type(torch.FloatTensor)
                )  # Add the validation accuracy of the current batch, ignoring all padding values
                if self.use_gpu is True:
                    # Move data to CPU for performance computations
                    scores, labels = scores.cpu(), labels.cpu()
                # Add the training ROC AUC of the current batch
                if model.n_outputs == 1:
                    try:
                        val_auc.append(
                            roc_auc_score(labels.numpy(),
                                          scores.detach().numpy()))
                    except Exception as e:
                        warnings.warn(
                            f'Couldn\'t calculate the validation AUC on step {step}. Received exception "{str(e)}".'
                        )
                else:
                    # It might happen that not all labels are present in the current batch;
                    # as such, we must focus on the ones that appear in the batch
                    labels_in_batch = labels.unique().long()
                    try:
                        val_auc.append(
                            roc_auc_score(labels.numpy(),
                                          softmax(scores[:, labels_in_batch],
                                                  dim=1).detach().numpy(),
                                          multi_class='ovr',
                                          average='macro',
                                          labels=labels_in_batch.numpy()))
                        # Also calculate a weighted version of the AUC; important for imbalanced dataset
                        val_auc_wgt.append(
                            roc_auc_score(labels.numpy(),
                                          softmax(scores[:, labels_in_batch],
                                                  dim=1).detach().numpy(),
                                          multi_class='ovr',
                                          average='weighted',
                                          labels=labels_in_batch.numpy()))
                    except Exception as e:
                        warnings.warn(
                            f'Couldn\'t calculate the validation AUC on step {step}. Received exception "{str(e)}".'
                        )
                # Remove the current features and labels from memory
                del features
                del labels
        # Calculate the average of the metrics over the batches
        val_loss = val_loss / len(val_iterator)
        val_acc = val_acc / len(val_iterator)
        val_auc = np.mean(val_auc)
        if model.n_outputs > 1:
            val_auc_wgt = np.mean(val_auc_wgt)
        # Return the validation metrics
        metrics = dict(val_loss=val_loss, val_acc=val_acc, val_auc=val_auc)
        if model.n_outputs > 1:
            metrics['val_auc_wgt'] = val_auc_wgt
        return metrics

    @override(TrainingOperator)
    def train_epoch(self, iterator, info):
        if self.num_workers == 1:
            model = self.model
        else:
            # Get the original model, as the current one is wrapped in DistributedDataParallel
            model = self.model.module
        print(f'DEBUG: TrainingOperator attributes:\n{vars(self)}')
        print(f'DEBUG: Model\'s attributes:\n{vars(model)}')
        # Register the current epoch
        epoch = info.get('epoch_idx', 0)
        # Number of iteration steps done so far
        step = info.get('step', 0)
        # Initialize the training metrics
        train_loss = 0
        train_acc = 0
        train_auc = list()
        if model.n_outputs > 1:
            train_auc_wgt = list()
        # try:
        # Loop through the training data
        for features, labels in du.utils.iterations_loop(
                iterator, see_progress=self.see_progress, desc='Steps'):
            # Activate dropout to train the model
            self.model.train()
            # Clear the gradients of all optimized variables
            self.optimizer.zero_grad()
            if self.is_custom is False:
                # Find the original sequence lengths
                seq_lengths = du.search_explore.find_seq_len(
                    labels, padding_value=self.padding_value)
            else:
                # No need to find the sequence lengths now
                seq_lengths = None
            if self.use_gpu is True:
                # Move data to GPU
                features, labels = features.to(self.device), labels.to(
                    self.device)
            # Do inference on the data
            if self.model_type.lower() == 'multivariate_rnn':
                (pred, correct_pred, scores, labels, step_train_loss) = (
                    du.deep_learning.inference_iter_multi_var_rnn(
                        self.model,
                        features,
                        labels,
                        padding_value=self.padding_value,
                        cols_to_remove=self.cols_to_remove,
                        is_train=True,
                        prob_output=True,
                        optimizer=self.optimizer,
                        is_custom=self.is_custom,
                        already_embedded=self.already_embedded,
                        seq_lengths=seq_lengths,
                        distributed_train=(self.num_workers > 1)))
            elif self.model_type.lower() == 'mlp':
                (pred, correct_pred, scores,
                 step_train_loss) = (du.deep_learning.inference_iter_mlp(
                     self.model,
                     features,
                     labels,
                     self.cols_to_remove,
                     is_train=True,
                     prob_output=True,
                     optimizer=self.optimizer))
            else:
                raise Exception(
                    f'ERROR: Invalid model type. It must be "multivariate_rnn" or "mlp", not {self.model_type}.'
                )
            # Add the training loss and accuracy of the current batch
            train_loss += step_train_loss
            step_train_acc = torch.mean(correct_pred.type(torch.FloatTensor))
            train_acc += step_train_acc
            if self.use_gpu is True:
                # Move data to CPU for performance computations
                scores, labels = scores.cpu(), labels.cpu()
            # Add the training ROC AUC of the current batch
            if model.n_outputs == 1:
                try:
                    step_train_auc = roc_auc_score(labels.numpy(),
                                                   scores.detach().numpy())
                    train_auc.append(step_train_auc)
                except Exception as e:
                    warnings.warn(
                        f'Couldn\'t calculate the training AUC on step {step}. Received exception "{str(e)}".'
                    )
                    step_train_auc = None
            else:
                # It might happen that not all labels are present in the current batch;
                # as such, we must focus on the ones that appear in the batch
                labels_in_batch = labels.unique().long()
                try:
                    step_train_auc = roc_auc_score(
                        labels.numpy(),
                        softmax(scores[:, labels_in_batch],
                                dim=1).detach().numpy(),
                        multi_class='ovr',
                        average='macro',
                        labels=labels_in_batch.numpy())
                    train_auc.append(step_train_auc)
                    # Also calculate a weighted version of the AUC; important for imbalanced dataset
                    step_train_auc_wgt = roc_auc_score(
                        labels.numpy(),
                        softmax(scores[:, labels_in_batch],
                                dim=1).detach().numpy(),
                        multi_class='ovr',
                        average='weighted',
                        labels=labels_in_batch.numpy())
                    train_auc_wgt.append(step_train_auc_wgt)
                except Exception as e:
                    warnings.warn(
                        f'Couldn\'t calculate the training AUC on step {step}. Received exception "{str(e)}".'
                    )
                    step_train_auc = None
                    step_train_auc_wgt = None
            # Count one more iteration step
            step += 1
            info['step'] = step
            # Deactivate dropout to test the model
            self.model.eval()
            # Remove the current features and labels from memory
            del features
            del labels
            # Run the current model on the validation set
            val_metrics = self.validate(self.validation_loader, info)
            if self.log_comet_ml is True:
                # Upload the current step's metrics to Comet ML
                self.experiment.log_metric('train_loss',
                                           step_train_loss,
                                           step=step)
                self.experiment.log_metric('train_acc',
                                           step_train_acc,
                                           step=step)
                self.experiment.log_metric('train_auc',
                                           step_train_auc,
                                           step=step)
                self.experiment.log_metric('val_loss',
                                           val_metrics['val_loss'],
                                           step=step)
                self.experiment.log_metric('val_acc',
                                           val_metrics['val_acc'],
                                           step=step)
                self.experiment.log_metric('val_auc',
                                           val_metrics['val_auc'],
                                           step=step)
                if model.n_outputs > 1:
                    self.experiment.log_metric('train_auc_wgt',
                                               step_train_auc_wgt,
                                               step=step)
                    self.experiment.log_metric('val_auc_wgt',
                                               val_metrics['val_auc_wgt'],
                                               step=step)
            # Display validation loss
            if step % self.print_every == 0:
                print(
                    f'Epoch {epoch} step {step}: Validation loss: {val_metrics["val_loss"]}; Validation Accuracy: {val_metrics["val_acc"]}; Validation AUC: {val_metrics["val_auc"]}'
                )
            # Check if the performance obtained in the validation set is the best so far (lowest loss value)
            if val_metrics['val_loss'] < self.val_loss_min:
                print(
                    f'New minimum validation loss: {self.val_loss_min} -> {val_metrics["val_loss"]}.'
                )
                # Update the minimum validation loss
                self.val_loss_min = val_metrics['val_loss']
                # Filename and path where the model will be saved
                model_filename = self.set_model_filename(
                    val_metrics['val_loss'])
                print(f'Saving model in {model_filename}')
                # Save the best performing model so far, along with additional information to implement it
                checkpoint = self.hyper_params
                checkpoint['state_dict'] = self.model.state_dict()
                torch.save(checkpoint, model_filename)
                # [TODO] Check if this really works locally or if it just saves in the temporary nodes
                # self.save(checkpoint, f'{self.models_path}{model_filename}')
                if self.log_comet_ml is True and self.comet_ml_save_model is True:
                    # Upload the model to Comet.ml
                    self.experiment.log_model(name=model_filename,
                                              file_or_folder=model_filename,
                                              overwrite=True)
        # except Exception as e:
        #     warnings.warn(f'There was a problem doing training epoch {epoch}. Ending current epoch. Original exception message: "{str(e)}"')
        # try:
        # Calculate the average of the metrics over the epoch
        train_loss = train_loss / len(iterator)
        train_acc = train_acc / len(iterator)
        train_auc = np.mean(train_auc)
        if model.n_outputs > 1:
            train_auc_wgt = np.mean(train_auc_wgt)
        # Remove attached gradients so as to be able to print the values
        train_loss = train_loss.detach()
        val_loss = val_metrics['val_loss'].detach()
        if self.use_gpu is True:
            # Move metrics data to CPU
            train_loss, val_loss = train_loss.cpu(), val_loss.cpu()
        if self.log_comet_ml is True:
            # Upload the current epoch's metrics to Comet ML
            self.experiment.log_metric('train_loss', train_loss, epoch=epoch)
            self.experiment.log_metric('train_acc', train_acc, epoch=epoch)
            self.experiment.log_metric('train_auc', train_auc, epoch=epoch)
            self.experiment.log_metric('val_loss', val_loss, epoch=epoch)
            self.experiment.log_metric('val_acc',
                                       val_metrics['val_acc'],
                                       epoch=epoch)
            self.experiment.log_metric('val_auc',
                                       val_metrics['val_auc'],
                                       epoch=epoch)
            self.experiment.log_epoch_end(epoch, step=step)
            if model.n_outputs > 1:
                self.experiment.log_metric('train_auc_wgt',
                                           train_auc_wgt,
                                           epoch=epoch)
                self.experiment.log_metric('val_auc_wgt',
                                           val_metrics['val_auc_wgt'],
                                           epoch=epoch)
        # Print a report of the epoch
        print(
            f'Epoch {epoch}: Training loss: {train_loss}; Training Accuracy: {train_acc}; Training AUC: {train_auc}; \
                Validation loss: {val_loss}; Validation Accuracy: {val_metrics["val_acc"]}; Validation AUC: {val_metrics["val_auc"]}'
        )
        print('----------------------')
        # except Exception as e:
        #     warnings.warn(f'There was a problem printing metrics from epoch {epoch}. Original exception message: "{str(e)}"')
        # Return the training metrics
        metrics = dict(train_loss=train_loss,
                       train_acc=train_acc,
                       train_auc=train_auc)
        if model.n_outputs > 1:
            metrics['train_auc_wgt'] = train_auc_wgt
        return metrics
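
As an aside, the inspect.getfullargspec trick in setup() above recovers constructor arguments from a live model instance; it relies on __init__ storing every argument under an attribute of the same name. A tiny self-contained illustration (TinyModel is made up):

import inspect

class TinyModel:
    def __init__(self, hidden_dim, n_layers=2):
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

model = TinyModel(hidden_dim=64)
arg_names = inspect.getfullargspec(TinyModel.__init__).args[1:]  # drop 'self'
hyper_params = {name: getattr(model, name) for name in arg_names}
print(hyper_params)  # {'hidden_dim': 64, 'n_layers': 2}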