def plot_mwn_function(args, meta_weight_net_state, close=True):
    linspace_losses = torch.linspace(start=0, end=20, steps=100).to(args.cuda)
    temp_MWN = MLP().to(args.cuda)
    temp_MWN.load_state_dict(meta_weight_net_state)
    temp_MWN.eval()
    predicted_weights = temp_MWN(linspace_losses.reshape(-1, 1)).data.cpu()

    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    plt.figure()
    plt.title('{}{}-{}-{}'.format(args.dataset, cifar_type,
                                  args.experiment_type, args.factor))
    plt.ylabel('weight')
    plt.xlabel('loss')
    plt.plot(linspace_losses.cpu(),
             predicted_weights,
             label='{}{}'.format(args.dataset, cifar_type))
    plt.legend(loc='best')
    plt.tight_layout()
    mng = plt.get_current_fig_manager()
    mng.full_screen_toggle()
    plt.savefig(os.path.join(args.directory, 'MWN_function_plot.png'))
    if close: plt.close()
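
A minimal invocation sketch for the helper above; the SimpleNamespace fields, their values, and the already-trained meta-weight network are illustrative assumptions, not part of the original example.

# hypothetical call site (all values below are assumptions for illustration)
import types
args = types.SimpleNamespace(cuda='cuda:0', dataset='CIFAR', cifar_type=10,
                             experiment_type='noisy_labels', factor=0.4,
                             directory='./results')
meta_weight_net = MLP().to(args.cuda)  # assumed to have been trained elsewhere
plot_mwn_function(args, meta_weight_net.state_dict(), close=True)
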
Example #2
                               random_state=1)
trainer = Data(train, features)
tester = Data(test, features)
'''model setup'''
model = MLP(features, num_class, params)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

print(f'The Model:\n{model}')
'''train the model'''
train_loss, train_acc, test_loss, test_acc = model_train(
    model,
    trainer,
    optimizer,
    criterion,
    tester=tester,
    batch_size=10,
    epochs=epochs)
'''plot the performance'''
performance_plot(train_loss, test_loss, 0.5, "loss", "Loss.jpeg")
performance_plot(train_acc, test_acc, 0.5, "accuracy", "Accuracy.jpeg")
'''test output'''
xs, ys = tester[:]
model.eval()
yhats = model(xs).softmax(dim=1).argmax(dim=1)
yhats = yhats.detach().numpy()
df = pd.DataFrame({
    "True labels": ys,
    "Predicted labels": yhats
}).to_string(index=False)
print(f'\nTest Output: \n{df}')
Example #3
def train_mlp(options, X_train, X_test, y_train, y_test):

    exp_name = options['exp_name']
    batch_size = options['batch_size']
    use_pca = options['use_pca']
    model_type = options['model_type']
    loss_fn = options['loss_fn']
    optim = options['optim']
    use_scheduler = options['use_scheduler']
    lr = options['lr']
    epochs = options['epochs']
    pca_var_hold = options['pca_var_hold']
    debug_mode = options['debug_mode']
    win_size = options['win_size']
    if exp_name is None:
        exp_name = (f"runs/Raw_{model_type}_pca_{use_pca}{round(pca_var_hold)}"
                    f"_{batch_size}_{round(lr, 2)}_win{win_size}"
                    f"_transf{options['transform_targets']}")
    if os.path.exists(exp_name):
        shutil.rmtree(exp_name)

    # time.sleep(1)
    writer = SummaryWriter(exp_name, flush_secs=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    TRANSF = Transform(y_train)

    if options['transform_targets']:
        y_train = TRANSF.fit(y_train)
        y_test = TRANSF.fit(y_test)
    if use_pca and 'Raw' in exp_name:
        scaler = PCA(pca_var_hold)
        scaler.fit(X_train)
        X_train = scaler.transform(X_train)
        X_test = scaler.transform(X_test)

    needed_dim = X_train.shape[1]

    dataset_train = MOOD(X_train,
                         y_train,
                         model_type=model_type,
                         data_type='train',
                         debug_mode=debug_mode)
    train_loader = DataLoader(dataset=dataset_train,
                              batch_size=batch_size,
                              shuffle=True)

    dataset_val = MOOD(X_test, y_test, model_type=model_type, data_type='val')
    valid_loader = DataLoader(dataset=dataset_val,
                              batch_size=batch_size,
                              shuffle=False)

    model = MLP(needed_dim=needed_dim, model_type=model_type, n_classes=None)
    model.to(device)
    if optim is None:
        print('you need to specify an optimizer')
        exit()
    elif optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', verbose=True, threshold=0.0001, patience=10)
    if loss_fn is None:
        print('you need to specify a loss function')
        exit()
    else:

        if loss_fn == 'mse':

            loss_fn = torch.nn.MSELoss()
        elif loss_fn == 'cross_entropy':
            loss_fn = torch.nn.CrossEntropyLoss()

    mean_train_losses = []
    mean_valid_losses = []
    valid_acc_list = []
    best = 0  # start small when tracking accuracy (use a large value when tracking loss) so the first improvement saves a model

    for epoch in range(epochs):
        model.train()
        train_losses = []
        valid_losses = []
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # print(images.shape)
            outputs = model(images)

            loss = loss_fn(outputs, labels)
            # print('loss: ',loss.item())
            writer.add_scalar('Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            del outputs
            # if (i * batch_size) % (batch_size * 100) == 0:
            #     print(f'{i * batch_size} / 50000')

        model.eval()
        correct_5_2 = 0
        correct_5_1 = 0

        total_loss = 0
        total = 0
        accsat = [1, 0.5, 0.05]
        accs = np.zeros(len(accsat))
        # corrs = np.zeros(len(accsat))
        correct_array = np.zeros(len(accsat))
        with torch.no_grad():
            for i, (images, labels) in enumerate(valid_loader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)
                """
                Correct if:
                preprocess.decode_target(output) == proprocess.decode(target)
                
                """
                for j in range(len(accsat)):
                    correct_array[j] += accat(
                        outputs,
                        labels,
                        thresh=accsat[j],
                        preprocess_instance=TRANSF,
                        transform_targets=options['transform_targets'])

                # total_loss += loss.item()
                total += labels.size(0)

                valid_losses.append(loss.item())

        mean_train_losses.append(np.mean(train_losses))
        mean_valid_losses.append(np.mean(valid_losses))
        # scheduler.step(np.mean(valid_losses))
        for i in range(len(accsat)):
            accs[i] = 100 * correct_array[i] / total
            writer.add_scalar('Acc/val_@' + str(accsat[i]), accs[i], epoch)

        if float(accs[1] / 100) > best:
            best = float(accs[1] / 100)
            torch.save(model.state_dict(),
                       os.path.join(os.getcwd(), 'models', 'meh.pth'))

        writer.add_scalar('Loss/val', np.mean(valid_losses), epoch)
        # valid_acc_list.append(accuracy)
    return best, np.mean(valid_losses)
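
For reference, a sketch of the options dictionary that train_mlp reads above; the keys mirror the lookups in the function body, while the concrete values (and the already-prepared X/y arrays) are illustrative assumptions.

# hypothetical options dict; keys follow train_mlp, values are assumptions
options = {
    'exp_name': None,          # auto-generated run name when None
    'batch_size': 32,
    'use_pca': True,
    'model_type': 'mlp',
    'loss_fn': 'mse',          # or 'cross_entropy'
    'optim': 'adam',           # or 'sgd'
    'use_scheduler': False,
    'lr': 1e-3,
    'epochs': 50,
    'pca_var_hold': 0.95,
    'debug_mode': False,
    'win_size': 5,
    'transform_targets': False,
}
best_acc, val_loss = train_mlp(options, X_train, X_test, y_train, y_test)
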

Example #4

            # update only the submodels (x,) for x being a participant
            subset_weights = update_weights_from_gradients(
                gradients[i - 1], submodel_dict[(i, )].state_dict())
            submodel_dict[(i, )].load_state_dict(subset_weights)

        totalRunTime += time.time() - start_time
        ######## Timing ends ########

        # update global weights
        global_model.load_state_dict(
            global_weights)  ### update the 2^n submodels as well

        # Calculate avg training accuracy over all users at every epoch.
        # For this case, since all users are participating in training, we need to adjust the code
        # list_acc, list_loss = [], []
        global_model.eval()
        # for c in range(args.num_users): (this doesn't apply in our case)
        # for idx in idxs_users:
        #     local_model = LocalUpdate(args=args, dataset=train_dataset[idx],
        #                               idxs=user_groups[idx], logger=logger)
        #     acc, loss = local_model.inference(model=global_model)
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
Example #5
def main(args):

    ts = datetime.datetime.now().timestamp()

    logger = SummaryWriter(
        os.path.join('exp/qgen_rl/', '{}_{}'.format(args.exp_name, ts)))
    logger.add_text('exp_name', args.exp_name)
    logger.add_text('args', str(args))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab = Vocab(os.path.join(args.data_dir, 'vocab.csv'), 3)
    category_vocab = CategoryVocab(
        os.path.join(args.data_dir, 'categories.csv'))

    data_loader = OrderedDict()
    splits = ['train', 'valid'] + (['test'] if args.test_set else list())
    for split in splits:
        file = os.path.join(args.data_dir, 'guesswhat.' + split + '.jsonl.gz')
        data_loader[split] = DataLoader(
            dataset=InferenceDataset(split,
                                     file,
                                     vocab,
                                     category_vocab,
                                     new_object=split == 'train',
                                     load_vgg_features=True),
            batch_size=args.batch_size,
            collate_fn=InferenceDataset.get_collate_fn(device),
            shuffle=split == 'train')

    if not args.belief_state:
        qgen = QGen.load(device, file=args.qgen_file)
    else:
        qgen = QGenBelief.load(device, file=args.qgen_file)
    guesser = Guesser.load(device, file=args.guesser_file)
    oracle = Oracle.load(device, file=args.oracle_file)

    generation_wrapper = GenerationWrapper(qgen, guesser, oracle)

    baseline = MLP(
        sizes=[qgen.hidden_size, args.baseline_hidden_size, 1],
        activation='relu', final_activation='relu', bias=[True, False])\
        .to(device)

    baseline_loss_fn = torch.nn.MSELoss(reduction='sum')
    baseline_optimizer = Optimizer(torch.optim.SGD,
                                   baseline.parameters(),
                                   lr=args.baseline_lr)
    qgen_optimizer = Optimizer(torch.optim.SGD,
                               qgen.parameters(),
                               lr=args.qgen_lr)

    split2strat = {
        'train': args.train_strategy,
        'valid': args.eval_strategy,
        'test': args.eval_strategy
    }

    best_val_acc = 0
    for epoch in range(args.epochs):

        for split in splits:

            if split == 'train':
                qgen.train()
                baseline.train()
            else:
                qgen.eval()
                baseline.eval()
            # bare torch.enable_grad()/torch.no_grad() calls have no effect;
            # set the grad mode explicitly for the whole split instead
            torch.set_grad_enabled(split == 'train')

            total_acc = list()
            for iteration, sample in enumerate(data_loader[split]):

                return_dict = generation_wrapper.generate(
                    sample,
                    vocab,
                    split2strat[split],
                    args.max_num_questions,
                    device,
                    args.belief_state,
                    return_keys=[
                        'mask', 'object_logits', 'hidden_states', 'log_probs',
                        'generations'
                    ])

                mask = return_dict['mask']
                object_logits = return_dict['object_logits']
                hidden_states = return_dict['hidden_states']
                log_probs = return_dict['log_probs']

                acc = accuarcy(object_logits, sample['target_id'])
                total_acc += [acc]

                mask = mask.float()

                rewards = torch.eq(
                    object_logits.topk(1)[1].view(-1),
                    sample['target_id'].view(-1)).float()
                rewards = rewards.unsqueeze(1).repeat(1, mask.size(1))
                rewards *= mask

                print("dialogue", return_dict['dialogue'][0],
                      return_dict['dialogue'].size())
                #print("log_probs", log_probs, log_probs.size())
                #print("mask", mask, mask.size())
                #print("rewards", rewards, rewards.size())

                baseline_preds = baseline(hidden_states.detach_()).squeeze(2)
                baseline_preds *= mask
                baseline_loss = baseline_loss_fn(
                    baseline_preds.view(-1), rewards.view(-1)) \
                    / baseline_preds.size(0)

                log_probs *= mask
                baseline_preds = baseline_preds.detach()
                policy_gradient_loss = torch.sum(log_probs *
                                                 (rewards - baseline_preds),
                                                 dim=1)
                policy_gradient_loss = -torch.mean(policy_gradient_loss)
                # policy_gradient_loss = - torch.sum(log_probs) / torch.sum(mask)
                #print(policy_gradient_loss_old.item(), policy_gradient_loss.item())

                if split == 'train':
                    qgen_optimizer.optimize(policy_gradient_loss,
                                            clip_norm_args=[args.clip_value])
                    baseline_optimizer.optimize(
                        baseline_loss, clip_norm_args=[args.clip_value])

                logger.add_scalar('{}_accuracy'.format(split), acc,
                                  iteration + len(data_loader[split]) * epoch)

                logger.add_scalar('{}_reward'.format(split),
                                  torch.mean(rewards).item(),
                                  iteration + len(data_loader[split]) * epoch)

                logger.add_scalar('{}_bl_loss'.format(split),
                                  baseline_loss.item(),
                                  iteration + len(data_loader[split]) * epoch)

                logger.add_scalar('{}_pg_loss'.format(split),
                                  policy_gradient_loss.item(),
                                  iteration + len(data_loader[split]) * epoch)

            model_saved = False
            if split == 'valid':
                if np.mean(total_acc) > best_val_acc:
                    best_val_acc = np.mean(total_acc)
                    qgen.save(file='bin/qgen_rl_{}_{}.pt'.format(
                        args.exp_name, ts),
                              accuarcy=np.mean(total_acc))
                    model_saved = True

            logger.add_scalar('epoch_{}_accuracy'.format(split),
                              np.mean(total_acc), epoch)

            print("Epoch {:3d}: {} Accuracy {:5.3f} {}".format(
                epoch, split.upper(),
                np.mean(total_acc) * 100, '*' if model_saved else ''))
        print("-" * 50)
Example #6
def train_MLP(train_X, train_Y, test_X, test_Y, batch_size=20, epochs=100):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    model = MLP(10, 20, 2)
    model.cuda()
    model.train()

    learn_rate = 1e-3
    grad_clip = 2.0
    dispFreq = 50
    validFreq = 200
    early_stop = 20
    weight = torch.FloatTensor([2.0, 1.0])
    loss_function = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)
    params = [p for p in model.parameters() if p.requires_grad]  # a list, so it can be reused for gradient clipping at every step

    dev_tensor = Variable(torch.FloatTensor(test_X).cuda())

    curr = 0
    uidx = 0
    # For Early-stopping

    best_step = 0
    for iepx in range(1, epochs + 1):
        for ibx in range(0, len(train_X), batch_size):
            if ibx + batch_size >= len(train_X):
                batch = Variable(
                    torch.FloatTensor(train_X[ibx:len(train_X)]).cuda())
                target = Variable(
                    torch.LongTensor(train_Y[ibx:len(train_X)]).cuda())
            else:
                batch = Variable(
                    torch.FloatTensor(train_X[ibx:ibx + batch_size]).cuda())
                target = Variable(
                    torch.LongTensor(train_Y[ibx:ibx + batch_size]).cuda())

            uidx += 1

            pred = model(batch)

            loss = loss_function(pred, target)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            optimizer.step()

            if np.mod(uidx, dispFreq) == 0:
                print('Epoch', iepx, '\tUpdate', uidx, '\tCost', loss.item())

            if np.mod(uidx, validFreq) == 0:
                # compute dev
                model.eval()
                out = model.forward(dev_tensor)
                model.train()
                # score = nn.NLLLoss(weight=weight)(out, vs_tensor).data[0]
                pred = categoryFromOutput(out)
                F1 = f1_score(test_Y, pred)

                curr_step = uidx // validFreq

                currscore = F1

                print('F1 on dev', F1)

                if currscore > curr:
                    curr = currscore
                    best_step = curr_step

                    # Save model
                    print('Saving model...', end='')
                    # torch.save(model.state_dict(), '%s_model_%s.pkl' % (saveto, run))
                    print('Done')

                if curr_step - best_step > early_stop:
                    print('Early stopping ...')
                    print(best_step)
                    print(curr)
                    return
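
A toy call matching the dimensions hard-coded above (10 input features, 2 classes); the random data below is purely illustrative and assumes the helper functions referenced by train_MLP are importable.

# hypothetical toy data; shapes follow MLP(10, 20, 2) above
import numpy as np
train_X = np.random.randn(500, 10).astype(np.float32)
train_Y = np.random.randint(0, 2, size=500)
test_X = np.random.randn(100, 10).astype(np.float32)
test_Y = np.random.randint(0, 2, size=100)
train_MLP(train_X, train_Y, test_X, test_Y, batch_size=20, epochs=100)
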
Example #7
class Model(object):
    def __init__(self, args):
        self.args = args
        self.model = MLP(args.nin, args.nh, args.nout, args.do)
        self.model.to(args.d)
        print("model params {}".format(count_parameters(self.model)))
        log_path = "logs/gmm"
        if os.path.exists(log_path) and os.path.isdir(log_path):
            shutil.rmtree(log_path)
        self.writer = SummaryWriter(log_path)
        nc = 2 if args.dataset == "toy" else 3
        self.best_loss = np.inf
        self.softmax = torch.nn.Softmax(dim=1)
        self.ce = torch.nn.CrossEntropyLoss()

    def compute_gmm_loss(self, z, logits):
        gamma = self.softmax(logits)
        phi, mu, cov = compute_params(z, gamma)
        sample_energy = compute_energy(z,
                                       phi=phi,
                                       mu=mu,
                                       cov=cov,
                                       size_average=True)
        return sample_energy

    def train_model(self, data_loader):
        self.model.train()
        train_loss = 0
        train_ce_loss = 0
        train_gmm_loss = 0
        for idx, batch in tqdm(enumerate(data_loader)):
            self.optimizer.zero_grad()
            x, y = batch
            x = x.to(self.args.d)
            y = y.to(self.args.d)
            logits = self.model(x)
            ce_loss = self.ce(logits, y)
            gmm_loss = self.compute_gmm_loss(x, logits)
            loss = ce_loss + self.args.alpha * gmm_loss
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            train_ce_loss += ce_loss.item()
            train_gmm_loss += gmm_loss.item()
        train_loss /= (idx + 1)
        train_ce_loss /= (idx + 1)
        train_gmm_loss /= (idx + 1)
        return train_loss, train_ce_loss, train_gmm_loss

    def eval_model(self, data_loader):
        val_loss = 0
        val_ce_loss = 0
        val_gmm_loss = 0
        self.model.eval()
        with torch.no_grad():
            for idx, batch in tqdm(enumerate(data_loader)):
                x, y = batch
                x = x.to(self.args.d)
                y = y.to(self.args.d)
                logits = self.model(x)
                ce_loss = self.ce(logits, y)
                gmm_loss = self.compute_gmm_loss(x, logits)
                loss = ce_loss + self.args.alpha * gmm_loss
                val_loss += loss.item()
                val_ce_loss += ce_loss.item()
                val_gmm_loss += gmm_loss.item()
        val_loss /= (idx + 1)
        val_ce_loss /= (idx + 1)
        val_gmm_loss /= (idx + 1)
        return val_loss, val_ce_loss, val_gmm_loss

    def fit(self, train_loader, val_loader):
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.lr,
                                          weight_decay=self.args.wd,
                                          amsgrad=True)
        loss = np.zeros((self.args.e, 2, 3))
        counter = 0
        for epoch in range(self.args.e):
            template = "Epoch {} train loss {:.4f} val loss {:.4f}"
            train_loss, train_ce_loss, train_gmm_loss = self.train_model(
                train_loader)
            val_loss, val_ce_loss, val_gmm_loss = self.eval_model(val_loader)
            loss[epoch, 0] = (train_loss, train_ce_loss, train_gmm_loss)
            loss[epoch, 1] = (val_loss, val_ce_loss, val_gmm_loss)
            #self.writer.add_scalars("total", {"train": train_loss, "val": val_loss}, global_step=epoch)
            print(template.format(epoch, train_loss, val_loss))
            if val_loss < self.best_loss:
                self.best_model = self.model.state_dict()
                self.best_loss = val_loss
                counter = 0
            else:
                counter += 1
            if counter == self.args.patience:
                break
        loss = loss[:epoch + 1]
        self.last_model = self.model.state_dict()
        torch.save(self.best_model,
                   "models/{}/gmm_best.pth".format(self.args.dataset))
        torch.save(self.last_model,
                   "models/{}/gmm_last.pth".format(self.args.dataset))
        return loss

    def compute_sample_energy(self, x):
        self.model.load_state_dict(self.last_model)
        self.model.eval()
        with torch.no_grad():
            x = torch.tensor(x, dtype=torch.float, device=self.args.d)
            logits = self.model(x)
            gamma = self.softmax(logits)
            phi, mu, cov = compute_params(x, gamma)
            sample_energy = compute_energy(x,
                                           phi=phi,
                                           mu=mu,
                                           cov=cov,
                                           size_average=False)
        return sample_energy.cpu().numpy(), logits.cpu().numpy()

    def evaluate(self, x, y, outlier_class):
        sample_energy, logits = self.compute_sample_energy(x)
        y_pred = logits.argmax(1)

        nclasses = len(np.unique(y))
        target = np.ones(len(logits))
        for i in outlier_class:
            target[y == i] = 0

        # precision recall
        precision, recall, _ = precision_recall_curve(target,
                                                      sample_energy,
                                                      pos_label=0)
        aucpr = auc(recall, precision)

        fpr, tpr, _ = roc_curve(target, sample_energy, pos_label=0)
        aucroc = auc(fpr, tpr)
        cm_multi = confusion_matrix(y, y_pred)

        return aucpr, aucroc, cm_multi

    def save_features(self, x, y, savename):
        np.save("{}_x.npy".format(savename), x.cpu().numpy())
        np.save("{}_y.npy".format(savename), y.cpu().numpy())
        return
    def objective(trial: optuna.Trial,  # with optuna
                  lr: int = None, output_dims: List = None, dropout: float = None  # without optuna
                  ):
        assert not (trial is not None and lr is not None)
        assert not (trial is not None and output_dims is not None)
        assert not (trial is not None and dropout is not None)
        assert not (trial is None and lr is None)
        assert not (trial is None and output_dims is None)
        assert not (trial is None and dropout is None)

        global model

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

        if trial is not None:
            if not isinstance(trial, optuna.Trial):
                # Best
                lr = trial.params['lr']
                nlayers = trial.params['nlayers']
                dropouts = [trial.params[f'dropout_l{i}'] for i in range(2)]
                output_dims = [trial.params[f'n_units_l{i}'] for i in range(nlayers)]
            else:
                # In study.
                logger(f'{"-" * 10} Trial #{trial.number} {"-" * 10}')

                # optuna settings
                # lr = trial.suggest_uniform('lr', lr_lower_bound, lr_upper_bound)
                lr = trial.suggest_categorical('lr', [1e-3, 3e-4, 2e-4, 1e-5])
                nlayers = trial.suggest_int('nlayers', nlayers_lower_bound, nlayers_upper_bound)
                dropouts = [
                        trial.suggest_categorical(f'dropout_l{i}', [0.2, 0.5, 0.7])
                        for i in range(2)
                        ]
                output_dims = [
                        int(trial.suggest_categorical(f'n_units_l{i}', list(range(odim_start, odim_end, odim_step))))
                        for i in range(nlayers)
                        ]
        else:
            nlayers = len(output_dims)
            dropouts = dropout  # without optuna, use the dropout argument directly

        logger('Setting up models...')
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = MLP(nlayers, dropouts, output_dims).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criteria = nn.CrossEntropyLoss(weight=loss_weight.to(device))

        best_acc = 0
        n_fail_in_a_raw = 0
        limit_n_fail_in_a_raw = 5

        # print('Start training...')
        for i_epoch in range(1, epoch+1):
            losses = []
            model.train()
            for tgts, sent1s, sent2s in train_dataloader:
                tgts = tgts.to(device)
                sent1s = sent1s.to(device)
                sent2s = sent2s.to(device)

                preds = model(sent1s, sent2s)
                loss = criteria(preds, tgts)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())

            model.eval()
            valid_losses = []
            valid_accs = []
            with torch.no_grad():
                for tgts, sent1s, sent2s in valid_dataloader:
                    tgts = tgts.to(device)
                    sent1s = sent1s.to(device)
                    sent2s = sent2s.to(device)

                    preds = model(sent1s, sent2s)
                    pred_idxs = preds.argmax(dim=1).tolist()

                    loss = criteria(preds, tgts)
                    acc = len([1 for p, t in zip(pred_idxs, tgts.tolist()) if p == t]) / len(tgts.tolist())
                    valid_losses.append(loss.item())
                    valid_accs.append(acc)

            logger(f'Train loss: {np.mean(losses)}')

            _loss = np.mean(valid_losses)
            _acc = np.mean(valid_accs)
            logger(f'Valid loss: {_loss}')
            logger(f'Valid accuracy: {_acc}')

            if _acc > best_acc:
                best_acc = _acc
                n_fail_in_a_raw = 0
            else:
                n_fail_in_a_raw += 1

            if n_fail_in_a_raw >= limit_n_fail_in_a_raw:
                break

            logger(f"{'-' * 25}\n")

        return best_acc
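
How the objective above would typically be driven; optuna's create_study/optimize calls are standard API, re-calling objective with study.best_trial exercises the non-Trial branch of the code, and the trial count is an arbitrary choice.

# hypothetical driver for the objective defined above
import optuna
study = optuna.create_study(direction='maximize')  # objective returns validation accuracy
study.optimize(objective, n_trials=30)
print('Best params:', study.best_trial.params)
best_acc = objective(study.best_trial)              # re-train with the best hyperparameters
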
def train_model(args, use_cuda=False):
    start_time = time.time()

    # Read values from args
    num_tasks = args.num_tasks
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    lr = args.lr
    num_epochs = args.num_epochs
    num_points = args.num_points
    coreset_select_method = args.select_method

    # Some parameters
    dataset_generation_test = False
    dataset_num_samples = 2000

    # Colours for plotting
    color = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']

    # Load / Generate toy data
    datagen = ToydataGenerator(max_iter=num_tasks,
                               num_samples=dataset_num_samples)

    plt.figure()
    datagen.reset()
    total_loaders = []
    criterion_cl = nn.CrossEntropyLoss()

    # Create model
    layer_size = [2, hidden_size, hidden_size, 2]
    model = MLP(layer_size, act='sigmoid')
    if use_cuda:
        model = model.cuda()

    # Optimiser
    opt = opt_fromp(model,
                    lr=lr,
                    prior_prec=1e-4,
                    grad_clip_norm=None,
                    tau=args.tau)

    memorable_points = None
    inducing_targets = None

    for tid in range(num_tasks):
        # If not first task, need to calculate and store regularisation-term-related quantities
        if tid > 0:

            def closure(task_id):
                memorable_points_t = memorable_points[task_id]
                if use_cuda:
                    memorable_points_t = memorable_points_t.cuda()
                opt.zero_grad()
                logits = model.forward(memorable_points_t)
                return logits

            opt.init_task(closure, tid, eps=1e-3)

        # Data generator for this task
        itrain, itest = datagen.next_task()
        itrainloader = DataLoader(dataset=itrain,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=8)
        itestloader = DataLoader(dataset=itest,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=8)
        inducingloader = DataLoader(dataset=itrain,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=8)
        iloaders = [itrainloader, itestloader]

        if tid == 0:
            total_loaders = [itrainloader]
        else:
            total_loaders.append(itrainloader)

        # Train and test
        cl_outputs = train(model,
                           iloaders,
                           memorable_points,
                           criterion_cl,
                           opt,
                           task_id=tid,
                           num_epochs=num_epochs,
                           use_cuda=use_cuda)

        # Select memorable past datapoints
        if coreset_select_method == 'random':
            i_memorable_points, i_inducing_targets = random_memorable_points(
                itrain, num_points=num_points, num_classes=2)
        else:
            i_memorable_points, i_inducing_targets = select_memorable_points(
                inducingloader,
                model,
                num_points=num_points,
                use_cuda=use_cuda)

        # Add memory points to set
        if tid > 0:
            memorable_points.append(i_memorable_points)
            inducing_targets.append(i_inducing_targets)
        else:
            memorable_points = [i_memorable_points]
            inducing_targets = [i_inducing_targets]

        # Update covariance (\Sigma)
        update_fisher(inducingloader, model, opt, use_cuda=use_cuda)

        # Plot visualisation (2D figure)
        cl_outputs, _ = torch.max(cl_outputs, dim=-1)
        cl_show = 2 * cl_outputs - 1

        cl_show = cl_show.detach()
        if use_cuda:
            cl_show = cl_show.cpu()
        cl_show = cl_show.numpy()
        cl_show = cl_show.reshape(datagen.test_shape)

        plt.figure()
        axs = plt.subplot(111)
        axs.title.set_text('FROMP')
        if not dataset_generation_test:
            plt.imshow(cl_show,
                       cmap='gray',
                       extent=(datagen.x_min, datagen.x_max, datagen.y_min,
                               datagen.y_max),
                       origin='lower')
        for l in range(tid + 1):
            idx = np.where(datagen.y == l)
            plt.scatter(datagen.X[idx][:, 0],
                        datagen.X[idx][:, 1],
                        c=color[l],
                        s=0.03)
            idx = np.where(datagen.y == l + datagen.offset)
            plt.scatter(datagen.X[idx][:, 0],
                        datagen.X[idx][:, 1],
                        c=color[l + datagen.offset],
                        s=0.03)
            if not dataset_generation_test:
                plt.scatter(memorable_points[l][:, 0],
                            memorable_points[l][:, 1],
                            c='m',
                            s=0.4,
                            marker='x')

        plt.show()

        # Calculate and print train accuracy and negative log likelihood
        with torch.no_grad():
            if not dataset_generation_test:
                model.eval()
                N = len(itrain)

                metric_task_id = 0
                nll_loss_avg = 0
                accuracy_avg = 0
                for metric_loader in total_loaders:
                    nll_loss = 0
                    correct = 0
                    for inputs, labels in metric_loader:
                        if use_cuda:
                            inputs, labels = inputs.cuda(), labels.cuda()

                        logits = model.forward(inputs)

                        nll_loss += nn.functional.cross_entropy(
                            torch.squeeze(logits, dim=-1), labels) * float(
                                inputs.shape[0])

                        # Calculate predicted classes
                        pred = logits.data.max(1, keepdim=True)[1]

                        # Count number of correctly predicted datapoints
                        correct += pred.eq(labels.data.view_as(pred)).sum()

                    nll_loss /= N
                    accuracy = float(correct) / float(N) * 100.

                    print(
                        'Task {}, Train accuracy: {:.2f}%, Train Loglik: {:.4f}'
                        .format(metric_task_id, accuracy, nll_loss))

                    metric_task_id += 1
                    nll_loss_avg += nll_loss
                    accuracy_avg += accuracy

                print('Avg train accuracy: {:.2f}%, Avg train Loglik: {:.4f}'.
                      format(accuracy_avg / metric_task_id,
                             nll_loss_avg / metric_task_id))

    print('Time taken: ', time.time() - start_time)
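
A sketch of the command-line setup that would feed train_model; the attribute names mirror the reads at the top of the function, while the default values here are assumptions for illustration.

# hypothetical argparse wiring for train_model
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--num_tasks', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--hidden_size', type=int, default=20)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--num_points', type=int, default=40)
parser.add_argument('--select_method', type=str, default='random')
parser.add_argument('--tau', type=float, default=1.0)
args = parser.parse_args()
train_model(args, use_cuda=torch.cuda.is_available())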