Example #1
def main(train=True, evaluate=True, model_file=None):
	# Params
	k = 100
	alpha = 0.1
	kappa = 0.5
	tau0 = 64
	var_i = 50
	size = 2000
	V = read_count('../../dataset/nyt/nytimes_voca_count.txt')
	if V is None:
		return
	# Model
	if model_file:
		ldaModel = OnlineLDAVB.load(model_file)
	else:
		ldaModel = OnlineLDAVB(alpha=alpha, K=k, V=V, kappa=kappa, tau0=tau0,\
				batch_size=size, var_max_iter=var_i)
	# Result directory
	result_dir = make_result_dir(ldaModel, '../result/nyt/')
	# Train
	if train:
		# Number of lines each time
		n = 2000
		# Number of documents
		D = read_count('../../dataset/nyt/document_counts.txt')
		if D is None:
			return
		# File
		corpus_file = open('../../dataset/nyt/nytimes_row_200k.txt')
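		# Stream the corpus: fit the model on n documents at a time (online variational Bayes updates)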
		for i in xrange(int(math.ceil(float(D) / n))):
			# Read n docs
			W = read_n_docs(corpus_file, n)
			# Train
			train_model(ldaModel, W, result_dir)
		# Save model
		save_model(ldaModel, result_dir)	
		# Save top words
		save_top_words(ldaModel, read_inverse_dic('../../dataset/nyt/vocab.nytimes.txt'), \
				result_dir)
		# Close file
		corpus_file.close()

	# Evaluate
	if evaluate:
		# Read test set
		W_tests = []
		for i in range(1, 11):
			W_obs = read('../../dataset/nyt/data_test_%d_part_1' % i)
			W_he = read('../../dataset/nyt/data_test_%d_part_2' % i)
			ids = [j for j in range(len(W_obs)) if W_obs[j].num_words > 0 \
					and W_he[j].num_words > 0]
			W_obs = np.array(W_obs)[ids]
			W_he = np.array(W_he)[ids]
			W_tests.append((W_obs, W_he))
		# Evaluate and save result
		evaluator = Predictive(ldaModel)
		evaluate_model(evaluator, W_tests, result_dir)
Example #2
    def step(self):
        state = self.env.reset()
        for iteration in range(1, 1 + self.n_iterations):
            states = []
            actions = []
            rewards = []
            values = []
            log_probs = []
            dones = []

            self.start_time = time.time()
            for t in range(self.horizon):
                # self.state_rms.update(state)
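                # normalize the observation with the running mean/std and clip to [-5, 5]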
                state = np.clip((state - self.state_rms.mean) /
                                (self.state_rms.var**0.5 + 1e-8), -5, 5)
                dist = self.agent.choose_dist(state)
                action = dist.sample().cpu().numpy()[0]
                # action = np.clip(action, self.agent.action_bounds[0], self.agent.action_bounds[1])
                log_prob = dist.log_prob(torch.Tensor(action))
                value = self.agent.get_value(state)
                next_state, reward, done, _ = self.env.step(action)

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                log_probs.append(log_prob)
                dones.append(done)

                if done:
                    state = self.env.reset()
                else:
                    state = next_state
            # self.state_rms.update(next_state)
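            # normalize the final observation and bootstrap its value (zeroed if the last step ended the episode)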
            next_state = np.clip((next_state - self.state_rms.mean) /
                                 (self.state_rms.var**0.5 + 1e-8), -5, 5)
            next_value = self.agent.get_value(next_state) * (1 - done)
            values.append(next_value)

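            # compute generalized advantage estimates (GAE) over the rollout, then update the actor and critic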
            advs = self.get_gae(rewards, values, dones)
            states = np.vstack(states)
            actor_loss, critic_loss = self.train(states, actions, advs, values,
                                                 log_probs)
            # self.agent.set_weights()
            self.agent.schedule_lr()
            eval_rewards = evaluate_model(self.agent, self.test_env,
                                          self.state_rms,
                                          self.agent.action_bounds)
            self.state_rms.update(states)
            self.print_logs(iteration, actor_loss, critic_loss, eval_rewards)
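
    # Hypothetical sketch (not part of the original snippet): get_gae is called above but not
    # shown here. A minimal generalized-advantage-estimation helper under the usual recursion,
    # assuming illustrative gamma/lam parameters, could look like this:
    def get_gae(self, rewards, values, dones, gamma=0.99, lam=0.95):
        # values is expected to hold one extra entry: the bootstrapped value of the final state
        advs = []
        gae = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
            gae = delta + gamma * lam * (1 - dones[t]) * gae
            advs.insert(0, gae)
        return np.array(advs, dtype=np.float32)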
Example #3
def main_train_loop(save_dir, model, args):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_class = len(args.cates)
    # resume from a checkpoint if one is available
    start_epoch = 0
    optimizer = initilize_optimizer(model, args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize the datasets and data loaders
    tr_dataset, te_dataset = get_datasets(args)

    train_sampler = None  # for non distributed training

    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               worker_init_fn=np.random.seed(
                                                   args.seed))
    test_loader = torch.utils.data.DataLoader(dataset=te_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True,
                                              drop_last=False,
                                              worker_init_fn=np.random.seed(
                                                  args.seed))

    #initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

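        # linearly decay the learning rate to zero over the second half of training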
        def lambda_rule(ep):
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        assert 0, "args.scheduler should be one of 'exponential', 'step', or 'linear'"

    #training starts from here
    tot_nelbo = []
    tot_kl_loss = []
    tot_x_reconst = []

    best_eval_metric = float('+inf')

    for epoch in range(start_epoch, args.epochs):
        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
        #train for one epoch
        model.train()
        for bidx, data in enumerate(train_loader):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            obj_type = data['cate_idx']
            y_one_hot = obj_type.new(
                np.eye(n_class)[obj_type]).to(device).float()
            step = bidx + len(train_loader) * epoch

            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)

            inputs = tr_batch.to(device)
            y_one_hot = y_one_hot.to(device)
            optimizer.zero_grad()
            inputs_dict = {'x': inputs, 'y_class': y_one_hot}
            ret = model(inputs_dict)
            loss, nelbo, kl_loss, x_reconst, cl_loss = ret['loss'], ret[
                'nelbo'], ret['kl_loss'], ret['x_reconst'], ret['cl_loss']
            loss.backward()
            optimizer.step()

            cur_loss = loss.cpu().item()
            cur_nelbo = nelbo.cpu().item()
            cur_kl_loss = kl_loss.cpu().item()
            cur_x_reconst = x_reconst.cpu().item()
            cur_cl_loss = cl_loss.cpu().item()
            tot_nelbo.append(cur_nelbo)
            tot_kl_loss.append(cur_kl_loss)
            tot_x_reconst.append(cur_x_reconst)
            if step % args.log_freq == 0:
                print(
                    "Epoch {0:6d} Step {1:12d} Loss {2:12.6f} Nelbo {3:12.6f} KL Loss {4:12.6f} Reconst Loss {5:12.6f} CL_Loss{6:12.6f}"
                    .format(epoch, step, cur_loss, cur_nelbo, cur_kl_loss,
                            cur_x_reconst, cur_cl_loss))

        #save checkpoint
        if (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-latest.pt'))

            eval_metric = evaluate_model(model, te_dataset, args)
            train_metric = evaluate_model(model, tr_dataset, args)

            print('Checkpoint: Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
                  format(eval_metric, train_metric))
            if eval_metric < best_eval_metric:
                best_eval_metric = eval_metric
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-best.pt'))
                print('new best model found!')

    save(model, optimizer, args.epochs,
         os.path.join(save_dir, 'checkpoint-latest.pt'))
    #save a final visualization of 5 reconstructed samples
    model.eval()
    with torch.no_grad():
        samples_A = model.reconstruct_input(inputs)  #sample_point(5)
        results = []
        for idx in range(5):
            res = visualize_point_clouds(
                samples_A[idx],
                tr_batch[idx],
                idx,
                pert_order=train_loader.dataset.display_axis_order)
            results.append(res)
        res = np.concatenate(results, axis=1)
        imsave(os.path.join(save_dir, 'images', '_epoch%d.png' % (epoch)),
               res.transpose((1, 2, 0)))

    #load the best model and compute eval metric:
    best_model_path = os.path.join(save_dir, 'checkpoint-best.pt')
    ckpt = torch.load(best_model_path)
    model.load_state_dict(ckpt['model'], strict=True)
    eval_metric = evaluate_model(model, te_dataset, args)
    train_metric = evaluate_model(model, tr_dataset, args)
    print(
        'Best model at epoch:{2} Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
        format(eval_metric, train_metric, ckpt['epoch']))
Example #4
def train(opts):
    """ Trains the model """
    torch.manual_seed(opts.seed)

    vocab = vocabs.load_vocabs_from_file(opts.vocab)

    dataset = Seq2VecDataset(opts.training_dir, vocab, opts.langs)

    num_training_data = int(len(dataset) * opts.train_val_ratio)
    num_val_data = len(dataset) - num_training_data

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [num_training_data, num_val_data])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=2,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=(opts.device.type == "cuda"),
        num_workers=2,
    )

    model = Seq2VecNN(len(vocab.get_word2id()),
                      2,
                      num_neurons_per_layer=[100, 25])
    model = model.to(opts.device)

    patience = opts.patience
    num_epochs = opts.epochs

    if opts.patience is None:
        patience = float("inf")
    else:
        num_epochs = float("inf")

    best_val_loss = float("inf")

    num_poor = 0
    epoch = 1

    optimizer = torch.optim.Adam(model.parameters(), lr=opts.learning_rate)

    if opts.resume_from_checkpoint and os.path.isfile(
            opts.resume_from_checkpoint):
        print("Loading from checkpoint")
        best_val_loss, num_poor, epoch = load_checkpoint(
            opts.resume_from_checkpoint, model, optimizer)
        print(
            f"Previous state > Epoch {epoch}: Val loss={best_val_loss}, num_poor={num_poor}"
        )

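    # train until the epoch budget is exhausted or validation loss has not improved for patience consecutive epochs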
    while epoch <= num_epochs and num_poor < patience:

        loss_function = torch.nn.CrossEntropyLoss()

        # Train
        train_loss, train_accuracy = train_for_one_epoch(
            model, loss_function, optimizer, train_dataloader, opts.device)

        # Evaluate the model
        eval_loss, eval_accuracy = test.evaluate_model(model, loss_function,
                                                       val_dataloader,
                                                       opts.device)

        print(
            f"Epoch={epoch} Train-Loss={train_loss} Train-Acc={train_accuracy} Test-Loss={eval_loss} Test-Acc={eval_accuracy} Num-Poor={num_poor}"
        )

        model.cpu()
        if eval_loss >= best_val_loss:
            num_poor += 1

        else:
            num_poor = 0
            best_val_loss = eval_loss

            print("Saved model")
            torch.save(model.state_dict(), opts.model_path)

        save_checkpoint(
            opts.save_checkpoint_to,
            model,
            optimizer,
            best_val_loss,
            num_poor,
            epoch,
        )
        print("Saved checkpoint")
        model.to(opts.device)

        epoch += 1

    if epoch > num_epochs:
        print(f"Finished {num_epochs} epochs")
    else:
        print(f"Loss did not improve after {patience} epochs")
Example #5
    def train_model(self, model, epoch):
        dataloader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True)

        # train on gpu
        if torch.cuda.is_available():
            model.cuda()

        steps = 0
        model.train()
        last_idx = math.floor(len(self.train_dataset) / self.config.batch_size)
        train_acc_epoch = 0
        for idx, batch in enumerate(dataloader):
            # if dataset exhausted, reset dataloader and reshuffle (safeguard)
            if idx == last_idx - 1:
                dataloader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True)

            title = batch[0] # shape: (batch_size, title_pad_length, token_amount, embedding_size)
            abstract = batch[1] # shape: (batch_size, abstract_pad_length, token_amount, embedding_size)
            target = batch[2] # shape: (batch_size)
            target = target.long()

            if torch.cuda.is_available():
                title = title.cuda()
                abstract = abstract.cuda()
                target = target.cuda()

            # convert to float tensor
            title = title.float()
            abstract = abstract.float()

            self.optim.zero_grad()
            prediction = model(title, abstract)
            loss = self.loss_fn(prediction, target)
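            # count correct predictions by comparing the argmax of the logits with the targets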
            num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()
            acc = 100.0 * num_corrects / len(batch[0])
            self.train_losses.append(loss.item())
            loss.backward()
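            # clip gradients before the optimizer step to keep updates stable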
            self.clip_gradient(model, 1e-1)
            self.optim.step()
            steps += 1

            train_acc_epoch += acc.item()

            if steps % self.config.print_frequency == 0:
                print(f'Epoch: {epoch}, Idx: {idx + 1}, Training Loss: {loss.item():.4f}, Training Accuracy: {acc.item(): .2f}%')

        train_acc_epoch /= last_idx - 1

        # evaluate
        results = test.evaluate_model(model, self.test_dataset, self.config.batch_size)
        test_acc_epoch = results["accuracy"]
        print(f"Accuracy at the end of epoch {epoch}: {test_acc_epoch}")

        # save values at the end of each epoch for plotting
        self.train_accs.append(train_acc_epoch)
        self.test_accs.append(test_acc_epoch)

        # halve the learning rate every lr_half_integral epochs
        if epoch % self.config.lr_half_integral == 0 and epoch != 0:
            self.current_lr /= 2
            for g in self.optim.param_groups:
                g['lr'] = self.current_lr
            # Alternative: rebuild the optimizer with the new learning rate and restore its state
            # tmp_state_dict = self.optim.state_dict()
            # self.optim = torch.optim.Adam(model.parameters(), lr=self.current_lr, weight_decay=self.config.weight_decay)
            # self.optim.load_state_dict(tmp_state_dict)
            print("Halving Learning Rate. New Learning Rate: ", self.current_lr)

        if epoch % 1 == 0:
            # Save model and values
            torch.save({
                'model': model.state_dict(),
                'optim': self.optim.state_dict(),
                'parameters': (epoch),
                'train_losses': self.train_losses,
                'train_accs': self.train_accs,
                'test_accs': self.test_accs,
                'best_test_acc': self.best_test_acc,
                'current_lr': self.current_lr
            }, f'model/trained_{epoch}.pth')
            print(f'Epoch {epoch}, model successfully saved.')

        # save separately if best model
        if (test_acc_epoch > self.best_test_acc):
            self.best_test_acc = test_acc_epoch
            torch.save({
                'model': model.state_dict(),
                'optim': self.optim.state_dict(),
                'parameters': (epoch),
                'train_losses': self.train_losses,
                'train_accs': self.train_accs,
                'test_accs': self.test_accs,
                'best_test_acc': self.best_test_acc,
                'current_lr': self.current_lr
            }, f'model/BestModel.pth')
            print(f'Epoch {epoch}, *BestModel* successfully saved.')
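
    # Hypothetical sketch (not shown in the snippet): clip_gradient is assumed above to clamp
    # each parameter gradient element-wise to [-clip_value, clip_value]; one common version:
    def clip_gradient(self, model, clip_value):
        for param in model.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-clip_value, clip_value)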
Example #6
            # Train model
            tbCallBack = TensorBoard(log_dir=outfolder,
                                     histogram_freq=0,
                                     write_graph=True,
                                     write_images=True)
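            # train for 25 epochs, holding out 20% of the training data for validation and logging to TensorBoard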
            model.fit(x_train,
                      y_train,
                      validation_split=0.2,
                      epochs=25,
                      batch_size=250,
                      callbacks=[tbCallBack])
            save(model, path=outfolder)

            # Evaluate model
            scores = evaluate_model(model, x_test, y_test)

            # Print scores
            print('\n\n')
            print("Loss: ", backend.get_value(scores[0]))
            print("Accuracy: ", backend.get_value(scores[1]) * 100, "%")

    # # Create model and print to log file
    # model = tanh_model()
    # # model = relu_model()
    # print_model_summary(model)

    # # Train model
    # tbCallBack = TensorBoard(log_dir=outfolder, histogram_freq=0, write_graph=True, write_images=True)
    # model.fit(x_train, y_train, validation_split=0.2, epochs=10, batch_size=250, callbacks=[tbCallBack])
    # save(model, path=outfolder)
Example #7
def train_model(
    name="",
    resume="",
    base_dir=utils.BASE_DIR,
    model_name="v0",
    chosen_diseases=None,
    n_epochs=10,
    batch_size=4,
    oversample=False,
    max_os=None,
    shuffle=False,
    opt="sgd",
    opt_params={},
    loss_name="wbce",
    loss_params={},
    train_resnet=False,
    log_metrics=None,
    flush_secs=120,
    train_max_images=None,
    val_max_images=None,
    test_max_images=None,
    experiment_mode="debug",
    save=True,
    save_cms=True,  # Note that in this case, save_cms (to disk) includes write_cms (to TB)
    write_graph=False,
    write_emb=False,
    write_emb_img=False,
    write_img=False,
    image_format="RGB",
    multiple_gpu=False,
):

    # Choose GPU
    device = utilsT.get_torch_device()
    print("Using device: ", device)

    # Common folders
    dataset_dir = os.path.join(base_dir, "dataset")

    # Dataset handling
    print("Loading train dataset...")
    train_dataset, train_dataloader = utilsT.prepare_data(
        dataset_dir,
        "train",
        chosen_diseases,
        batch_size,
        oversample=oversample,
        max_os=max_os,
        shuffle=shuffle,
        max_images=train_max_images,
        image_format=image_format,
    )
    train_samples, _ = train_dataset.size()

    print("Loading val dataset...")
    val_dataset, val_dataloader = utilsT.prepare_data(
        dataset_dir,
        "val",
        chosen_diseases,
        batch_size,
        max_images=val_max_images,
        image_format=image_format,
    )
    val_samples, _ = val_dataset.size()

    # Should be the same as chosen_diseases
    chosen_diseases = list(train_dataset.classes)
    print("Chosen diseases: ", chosen_diseases)

    if resume:
        # Load model and optimizer
        model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = models.load_model(
            base_dir, resume, experiment_mode="", device=device)
        model.train(True)
    else:
        # Create model
        model = models.init_empty_model(model_name,
                                        chosen_diseases,
                                        train_resnet=train_resnet).to(device)

        # Create optimizer
        OptClass = optimizers.get_optimizer_class(opt)
        optimizer = OptClass(model.parameters(), **opt_params)
        # print("OPT: ", opt_params)

    # Allow multiple GPUs
    if multiple_gpu:
        model = DataParallel(model)

    # Tensorboard log options
    run_name = utils.get_timestamp()
    if name:
        run_name += "_{}".format(name)

    if len(chosen_diseases) == 1:
        run_name += "_{}".format(chosen_diseases[0])
    elif len(chosen_diseases) == 14:
        run_name += "_all"

    log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode)

    print("Run name: ", run_name)
    print("Saved TB in: ", log_dir)

    writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    # Create validator engine
    validator = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           False))

    val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    val_loss.attach(validator, loss_name)

    utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(validator, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    # Create trainer engine
    trainer = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           True))

    train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    train_loss.attach(trainer, loss_name)

    utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)

    # TODO: Early stopping
    #     def score_function(engine):
    #         val_loss = engine.state.metrics[loss_name]
    #         return -val_loss

    #     handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    #     validator.add_event_handler(Events.COMPLETED, handler)

    # Metrics callbacks
    if log_metrics is None:
        log_metrics = list(ALL_METRICS)

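    # write the loss and each per-disease metric for one run (train or val) to TensorBoard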
    def _write_metrics(run_type, metrics, epoch, wall_time):
        loss = metrics.get(loss_name, 0)

        writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time)

        for metric_base_name in log_metrics:
            for disease in chosen_diseases:
                metric_value = metrics.get(
                    "{}_{}".format(metric_base_name, disease), -1)
                writer.add_scalar(
                    "{}_{}/{}".format(metric_base_name, disease, run_type),
                    metric_value, epoch, wall_time)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Log all metrics to TB
        _write_metrics("train", trainer.state.metrics, epoch, wall_time)
        _write_metrics("val", validator.state.metrics, epoch, wall_time)

        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        tb_write_histogram(writer, model, epoch, wall_time)

        print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})".
              format(epoch, max_epochs, train_loss, val_loss,
                     utils.duration_to_str(int(timer._elapsed()))))

    # Hparam dict
    hparam_dict = {
        "resume": resume,
        "n_diseases": len(chosen_diseases),
        "diseases": ",".join(chosen_diseases),
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "model_name": model_name,
        "opt": opt,
        "loss": loss_name,
        "samples (train, val)": "{},{}".format(train_samples, val_samples),
        "train_resnet": train_resnet,
        "multiple_gpu": multiple_gpu,
    }

    def copy_params(params_dict, base_name):
        for name, value in params_dict.items():
            hparam_dict["{}_{}".format(base_name, name)] = value

    copy_params(loss_params, "loss")
    copy_params(opt_params, "opt")
    print("HPARAM: ", hparam_dict)

    # Train
    print("-" * 50)
    print("Training...")
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Average time per epoch: ", duration_per_epoch)
    print("-" * 50)

    ## Write all hparams
    hparam_dict["duration_per_epoch"] = duration_per_epoch

    # FIXME: this is commented to avoid having too many hparams in TB frontend
    # metrics
    #     def copy_metrics(engine, engine_name):
    #         for metric_name, metric_value in engine.state.metrics.items():
    #             hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value
    #     copy_metrics(trainer, "train")
    #     copy_metrics(validator, "val")

    print("Writing TB hparams")
    writer.add_hparams(hparam_dict, {})

    # Save model to disk
    if save:
        print("Saving model...")
        models.save_model(base_dir, run_name, model_name, experiment_mode,
                          hparam_dict, trainer, model, optimizer)

    # Write graph to TB
    if write_graph:
        print("Writing TB graph...")
        tb_write_graph(writer, model, train_dataloader, device)

    # Write embeddings to TB
    if write_emb:
        print("Writing TB embeddings...")
        image_size = 256 if write_emb_img else 0

        # FIXME: be able to select images (balanced, train vs val, etc)
        image_list = list(train_dataset.label_index["FileName"])[:1000]
        # disease = chosen_diseases[0]
        # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1]
        # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0]
        # positive_images = list(positive["FileName"])[:25]
        # negative_images = list(negative["FileName"])[:25]
        # image_list = positive_images + negative_images

        all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings(
            model,
            train_dataset,
            device,
            image_list=image_list,
            image_size=image_size)
        tb_write_embeddings(
            writer,
            chosen_diseases,
            all_images,
            all_embeddings,
            all_predictions,
            all_ground_truths,
            global_step=n_epochs,
            use_images=write_emb_img,
            tag="1000_{}".format("img" if write_emb_img else "no_img"),
        )

    # Save confusion matrices (it is expensive to compute them afterwards)
    if save_cms:
        print("Saving confusion matrices...")
        # Ensure the output folder exists
        cms_dir = os.path.join(base_dir, "cms", experiment_mode)
        os.makedirs(cms_dir, exist_ok=True)
        base_fname = os.path.join(cms_dir, run_name)

        n_diseases = len(chosen_diseases)

        def extract_cms(metrics):
            """Extract confusion matrices from a metrics dict."""
            cms = []
            for disease in chosen_diseases:
                key = "cm_" + disease
                if key not in metrics:
                    cm = np.array([[-1, -1], [-1, -1]])
                else:
                    cm = metrics[key].numpy()

                cms.append(cm)
            return np.array(cms)

        # Train confusion matrix
        train_cms = extract_cms(trainer.state.metrics)
        np.save(base_fname + "_train", train_cms)
        tb_write_cms(writer, "train", chosen_diseases, train_cms)

        # Validation confusion matrix
        val_cms = extract_cms(validator.state.metrics)
        np.save(base_fname + "_val", val_cms)
        tb_write_cms(writer, "val", chosen_diseases, val_cms)

        # Combined confusion matrices (train + val)
        all_cms = train_cms + val_cms
        np.save(base_fname + "_all", all_cms)

        # Print to console
        if len(chosen_diseases) == 1:
            print("Train CM: ")
            print(train_cms[0])
            print("Val CM: ")
            print(val_cms[0])


#             print("Train CM 2: ")
#             print(trainer.state.metrics["cm_" + chosen_diseases[0]])
#             print("Val CM 2: ")
#             print(validator.state.metrics["cm_" + chosen_diseases[0]])

    if write_img:
        # NOTE: this option is not recommended, use Testing notebook to plot and analyze images

        print("Writing images to TB...")

        test_dataset, test_dataloader = utilsT.prepare_data(
            dataset_dir,
            "test",
            chosen_diseases,
            batch_size,
            max_images=test_max_images,
        )

        # TODO: add a way to select images?
        # image_list = list(test_dataset.label_index["FileName"])[:3]

        # Examples in test_dataset (with bboxes available):
        image_list = [
            # "00010277_000.png", # (Effusion, Infiltrate, Mass, Pneumonia)
            # "00018427_004.png", # (Atelectasis, Effusion, Mass)
            # "00021703_001.png", # (Atelectasis, Effusion, Infiltrate)
            # "00028640_008.png", # (Effusion, Infiltrate)
            # "00019124_104.png", # (Pneumothorax)
            # "00019124_090.png", # (Nodule)
            # "00020318_007.png", # (Pneumothorax)
            "00000003_000.png",  # (0)
            # "00000003_001.png", # (0)
            # "00000003_002.png", # (0)
            "00000732_005.png",  # (Cardiomegaly, Pneumothorax)
            # "00012261_001.png", # (Cardiomegaly, Pneumonia)
            # "00013249_033.png", # (Cardiomegaly, Pneumonia)
            # "00029808_003.png", # (Cardiomegaly, Pneumonia)
            # "00022215_012.png", # (Cardiomegaly, Pneumonia)
            # "00011402_007.png", # (Cardiomegaly, Pneumonia)
            # "00019018_007.png", # (Cardiomegaly, Infiltrate)
            # "00021009_001.png", # (Cardiomegaly, Infiltrate)
            # "00013670_151.png", # (Cardiomegaly, Infiltrate)
            # "00005066_030.png", # (Cardiomegaly, Infiltrate, Effusion)
            "00012288_000.png",  # (Cardiomegaly)
            "00008399_007.png",  # (Cardiomegaly)
            "00005532_000.png",  # (Cardiomegaly)
            "00005532_014.png",  # (Cardiomegaly)
            "00005532_016.png",  # (Cardiomegaly)
            "00005827_000.png",  # (Cardiomegaly)
            # "00006912_007.png", # (Cardiomegaly)
            # "00007037_000.png", # (Cardiomegaly)
            # "00007043_000.png", # (Cardiomegaly)
            # "00012741_004.png", # (Cardiomegaly)
            # "00007551_020.png", # (Cardiomegaly)
            # "00007735_040.png", # (Cardiomegaly)
            # "00008339_010.png", # (Cardiomegaly)
            # "00008365_000.png", # (Cardiomegaly)
            # "00012686_003.png", # (Cardiomegaly)
        ]

        tb_write_images(writer, model, test_dataset, chosen_diseases, n_epochs,
                        device, image_list)

    # Close TB writer
    if experiment_mode != "debug":
        writer.close()

    # Run post_train
    print("-" * 50)
    print("Running post_train...")

    print("Loading test dataset...")
    test_dataset, test_dataloader = utilsT.prepare_data(
        dataset_dir,
        "test",
        chosen_diseases,
        batch_size,
        max_images=test_max_images)

    save_cms_with_names(run_name, experiment_mode, model, test_dataset,
                        test_dataloader, chosen_diseases)

    evaluate_model(run_name,
                   model,
                   optimizer,
                   device,
                   loss_name,
                   loss_params,
                   chosen_diseases,
                   test_dataloader,
                   experiment_mode=experiment_mode,
                   base_dir=base_dir)

    # Return values for debugging
    model_run = ModelRun(model, run_name, model_name, chosen_diseases)
    if experiment_mode == "debug":
        model_run.save_debug_data(writer, trainer, validator, train_dataset,
                                  train_dataloader, val_dataset,
                                  val_dataloader)

    return model_run