Example #1
    def end_task(self, dataset):
        # reinit network
        self.net = dataset.get_backbone()
        self.net.to(self.device)
        self.net.train()
        self.opt = SGD(self.net.parameters(), lr=self.args.lr)

        # gather data
        all_data = torch.cat(self.old_data)
        all_labels = torch.cat(self.old_labels)

        # train
        for e in range(1):  # single pass; the full version loops over range(self.args.n_epochs)
            rp = torch.randperm(len(all_data))
            for i in range(math.ceil(len(all_data) / self.args.batch_size)):
                idx = rp[i * self.args.batch_size:(i + 1) * self.args.batch_size]
                inputs = all_data[idx]
                labels = all_labels[idx]
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.opt.zero_grad()
                outputs = self.net(inputs)
                loss = self.loss(outputs, labels.long())
                loss.backward()
                self.opt.step()
                progress_bar(i,
                             math.ceil(len(all_data) / self.args.batch_size),
                             e, 'J', loss.item())
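
A standalone sketch of the shuffled mini-batch loop used above, on synthetic tensors; all names here are local to the sketch and not part of the repository. Indexing with the sliced permutation is equivalent to slicing the fully shuffled tensor, but avoids materializing a permuted copy of the whole dataset on every step.

import math
import torch

all_data = torch.randn(100, 3, 32, 32)
all_labels = torch.randint(0, 10, (100,))
batch_size = 32

rp = torch.randperm(len(all_data))
for i in range(math.ceil(len(all_data) / batch_size)):
    idx = rp[i * batch_size:(i + 1) * batch_size]
    inputs, labels = all_data[idx], all_labels[idx]
    # forward/backward on (inputs, labels) would go here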
Example #2
def train(args: Namespace):
    """
    The training process, including evaluations and loggers.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    if args.csv_log:
        from utils.loggers import CsvLogger

    dataset = get_gcl_dataset(args)
    backbone = dataset.get_backbone()
    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())
    model.net.to(model.device)

    model_stash = create_fake_stash(model, args)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
    if args.tensorboard:
        tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash)

    model.net.train()
    epoch, i = 0, 0
    while not dataset.train_over:
        inputs, labels, not_aug_inputs = dataset.get_train_data()
        inputs, labels = inputs.to(model.device), labels.to(model.device)
        not_aug_inputs = not_aug_inputs.to(model.device)
        loss = model.observe(inputs, labels, not_aug_inputs)
        progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss)
        if args.tensorboard:
            tb_logger.log_loss_gcl(loss, i)
        i += 1

    if model.NAME == 'joint_gcl':
        model.end_task(dataset)

    acc = evaluate(model, dataset)
    print('Accuracy:', acc)

    if args.csv_log:
        csv_logger.log(acc)
        csv_logger.write(vars(args))
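
For reference, a hedged sketch of the `args` fields this function and the models it drives actually read: `csv_log`, `tensorboard`, and `batch_size` appear above, while `lr` and `n_epochs` are used by the models. The dataset name is a placeholder, not a value confirmed by this snippet.

from argparse import Namespace

args = Namespace(model='joint_gcl', dataset='<gcl-dataset-name>',
                 lr=0.1, batch_size=16, n_epochs=1,
                 csv_log=False, tensorboard=False)
# train(args)  # assuming the module's own imports (get_gcl_dataset, ...) are in scope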
Example #3
    def end_task(self, dataset):
        if dataset.NAME == 'seq-core50':
            raise NotImplementedError('Do you hear the RAM crying at night?')
        else:
            self.old_data.append(dataset.train_loader.dataset.data)
            self.old_labels.append(torch.tensor(dataset.train_loader.dataset.targets))
            self.current_task += 1
            # for non-incremental joint training
            if len(dataset.test_loaders) != dataset.N_TASKS: return

            # reinit network
            self.net = dataset.get_backbone()
            self.net.to(self.device)
            self.net.train()
            self.opt = SGD(self.net.parameters(), lr=self.args.lr)

            # prepare dataloader
            all_data, all_labels = None, None
            for i in range(len(self.old_data)):
                if all_data is None:
                    all_data = self.old_data[i]
                    all_labels = self.old_labels[i]
                else:
                    all_data = np.concatenate([all_data, self.old_data[i]])
                    all_labels = np.concatenate([all_labels, self.old_labels[i]])

            transform = dataset.get_transform()
            if transform is None:
                transform = transforms.ToTensor()
            temp_dataset = ValidationDataset(all_data, all_labels,
                                             transform=transform)
            loader = torch.utils.data.DataLoader(temp_dataset, batch_size=self.args.batch_size, shuffle=True)

            # train
            for e in range(self.args.n_epochs):
                for i, batch in enumerate(loader):
                    inputs, labels = batch
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, len(loader), e, 'J', loss.item())
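
The dataloader preparation can be reproduced standalone with torch's built-in TensorDataset in place of the repository's ValidationDataset; a sketch with fake data:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# two fake "tasks" of numpy data, concatenated the same way as old_data above
task_a = np.random.rand(20, 3, 32, 32).astype('float32'), np.random.randint(0, 5, 20)
task_b = np.random.rand(20, 3, 32, 32).astype('float32'), np.random.randint(5, 10, 20)
all_data = np.concatenate([task_a[0], task_b[0]])
all_labels = np.concatenate([task_a[1], task_b[1]])

temp_dataset = TensorDataset(torch.from_numpy(all_data), torch.from_numpy(all_labels))
loader = DataLoader(temp_dataset, batch_size=32, shuffle=True)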
Example #4
    def end_task(self, dataset):
        if dataset.SETTING != 'domain-il':
            self.old_data.append(dataset.train_loader.dataset.data)
            self.old_labels.append(
                torch.tensor(dataset.train_loader.dataset.targets))
            self.current_task += 1

            # for non-incremental joint training
            if len(dataset.test_loaders) != dataset.N_TASKS: return

            # reinit network
            self.net = dataset.get_backbone()
            self.net.to(self.device)
            self.net.train()
            self.opt = SGD(self.net.parameters(), lr=self.args.lr)

            # prepare dataloader
            all_data, all_labels = None, None
            for i in range(len(self.old_data)):
                if all_data is None:
                    all_data = self.old_data[i]
                    all_labels = self.old_labels[i]
                else:
                    all_data = np.concatenate([all_data, self.old_data[i]])
                    all_labels = np.concatenate(
                        [all_labels, self.old_labels[i]])

            temp_dataset = ValidationDataset(all_data,
                                             all_labels,
                                             transform=dataset.TRANSFORM)
            loader = torch.utils.data.DataLoader(
                temp_dataset, batch_size=self.args.batch_size, shuffle=True)

            # train
            for e in range(self.args.n_epochs):
                for i, batch in enumerate(loader):
                    inputs, labels = batch
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)

                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, len(loader), e, 'J', loss.item())
        else:
            self.old_data.append(dataset.train_loader)
            # train
            if len(dataset.test_loaders) != dataset.N_TASKS: return
            loader_caches = [[] for _ in range(len(self.old_data))]
            sources = torch.randint(5, (128, ))
            all_inputs = []
            all_labels = []
            for source in self.old_data:
                for x, l, _ in source:
                    all_inputs.append(x)
                    all_labels.append(l)
            all_inputs = torch.cat(all_inputs)
            all_labels = torch.cat(all_labels)
            bs = self.args.batch_size
            for e in range(self.args.n_epochs):
                order = torch.randperm(len(all_inputs))
                for i in range(int(math.ceil(len(all_inputs) / bs))):
                    idx = order[i * bs:(i + 1) * bs]
                    inputs = all_inputs[idx]
                    labels = all_labels[idx]
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)
                    self.opt.zero_grad()
                    outputs = self.net(inputs)
                    loss = self.loss(outputs, labels.long())
                    loss.backward()
                    self.opt.step()
                    progress_bar(i, int(math.ceil(len(all_inputs) / bs)), e,
                                 'J', loss.item())
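
The domain-IL branch's gathering step can be exercised in isolation; a sketch with synthetic stand-in loaders, where the third element per batch mirrors the not_aug_inputs slot:

import torch
from torch.utils.data import DataLoader, TensorDataset

def fake_loader(n):
    ds = TensorDataset(torch.randn(n, 4), torch.randint(0, 2, (n,)), torch.zeros(n))
    return DataLoader(ds, batch_size=8)

old_data = [fake_loader(16), fake_loader(24)]
all_inputs, all_labels = [], []
for source in old_data:
    for x, l, _ in source:
        all_inputs.append(x)
        all_labels.append(l)
all_inputs = torch.cat(all_inputs)   # shape: (40, 4)
all_labels = torch.cat(all_labels)   # shape: (40,)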
Example #5
    def get_anchors(self, dataset):
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine tune on memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape,
                             requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                cum_loss = 0

                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma *
                                 (self.spare_model.features(e_t.unsqueeze(0)) -
                                  self.phi)**2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                if i % 10 == 9:
                    progress_bar(i, self.anchor_optimization_steps, i,
                                 'A' + str(a_class), cum_loss)

                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()
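
A minimal standalone sketch of the three-term anchor objective above, on a tiny linear model: the input e_t is optimized so the class is forgotten under the fine-tuned parameters theta_m, kept under the old parameters theta_t, plus a feature-anchoring term. Here theta_m, phi, and the feature term are stand-ins, since the real spare_model and its features() come from the repository.

import torch
from torch.optim import SGD

torch.manual_seed(0)
model = torch.nn.Linear(8, 3)
ce = torch.nn.CrossEntropyLoss()
theta_t = [p.detach().clone() for p in model.parameters()]        # params before fine-tuning
theta_m = [p.detach().clone() + 0.1 for p in model.parameters()]  # stand-in for fine-tuned params
phi = torch.zeros(3)                                              # stand-in feature target
a_class = torch.tensor([1])

e_t = torch.rand(8, requires_grad=True)
e_t_opt = SGD([e_t], lr=0.1)
for _ in range(10):
    e_t_opt.zero_grad()
    model.zero_grad()
    # maximize the loss under theta_m, minimize it under theta_t
    for params, sign in ((theta_m, -1.0), (theta_t, 1.0)):
        with torch.no_grad():
            for p, q in zip(model.parameters(), params):
                p.copy_(q)
        (sign * ce(model(e_t.unsqueeze(0)), a_class)).backward()
    # feature-anchoring term (the model output doubles as "features" in this sketch)
    (0.1 * (model(e_t.unsqueeze(0)) - phi).pow(2).sum()).backward()
    e_t_opt.step()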
Example #6
def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    model.net.to(model.device)
    results, results_mask_classes = [], []

    model_stash = create_stash(model, args, dataset)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
    if args.tensorboard:
        tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash)
        model_stash['tensorboard_name'] = tb_logger.get_name()

    dataset_copy = get_dataset(args)
    for t in range(dataset.N_TASKS):
        model.net.train()
        _, _ = dataset_copy.get_data_loaders()
    if model.NAME != 'icarl' and model.NAME != 'pnn':
        random_results_class, random_results_task = evaluate(
            model, dataset_copy)

    print(file=sys.stderr)
    for t in range(dataset.N_TASKS):
        model.net.train()
        train_loader, test_loader = dataset.get_data_loaders()
        if hasattr(model, 'begin_task'):
            model.begin_task(dataset)
        if t:
            accs = evaluate(model, dataset, last=True)
            results[t - 1] = results[t - 1] + accs[0]
            if dataset.SETTING == 'class-il':
                results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]
        for epoch in range(args.n_epochs):
            for i, data in enumerate(train_loader):
                if hasattr(dataset.train_loader.dataset, 'logits'):
                    inputs, labels, not_aug_inputs, logits = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    logits = logits.to(model.device)
                    loss = model.observe(inputs, labels, not_aug_inputs,
                                         logits)
                else:
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(model.device), labels.to(
                        model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    loss = model.observe(inputs, labels, not_aug_inputs)

                progress_bar(i, len(train_loader), epoch, t, loss)

                if args.tensorboard:
                    tb_logger.log_loss(loss, args, epoch, t, i)

                model_stash['batch_idx'] = i + 1
            model_stash['epoch_idx'] = epoch + 1
            model_stash['batch_idx'] = 0
        model_stash['task_idx'] = t + 1
        model_stash['epoch_idx'] = 0

        if hasattr(model, 'end_task'):
            model.end_task(dataset)

        accs = evaluate(model, dataset)
        results.append(accs[0])
        results_mask_classes.append(accs[1])

        mean_acc = np.mean(accs, axis=1)
        print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)

        model_stash['mean_accs'].append(mean_acc)
        if args.csv_log:
            csv_logger.log(mean_acc)
        if args.tensorboard:
            tb_logger.log_accuracy(np.array(accs), mean_acc, args, t)

    if args.csv_log:
        csv_logger.add_bwt(results, results_mask_classes)
        csv_logger.add_forgetting(results, results_mask_classes)
        if model.NAME != 'icarl' and model.NAME != 'pnn':
            csv_logger.add_fwt(results, random_results_class,
                               results_mask_classes, random_results_task)

    if args.tensorboard:
        tb_logger.close()
    if args.csv_log:
        csv_logger.write(vars(args))
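
For reference, a sketch of the accuracy bookkeeping: evaluate is assumed to return a pair of per-task accuracy lists (class-IL and task-IL, matching accs[0] and accs[1] above), and np.mean(accs, axis=1) yields one mean per setting. The numbers below are fake.

import numpy as np

accs = [[52.0, 48.5, 61.0],   # class-IL accuracy per seen task (fake numbers)
        [90.2, 88.7, 92.4]]   # task-IL accuracy per seen task (fake numbers)
mean_acc = np.mean(accs, axis=1)
print(mean_acc)               # [53.83... 90.43...]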
Example #7
def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    last_wish = LastWish()
    model.net.to(model.device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        get_par_fn = model.net.get_params
        feat_fn = model.net.features
        class_fn = model.net.classifier
        set_par_fn = model.net.set_params
        model.net = nn.DataParallel(model.net)
        setattr(model.net, 'get_params', get_par_fn)
        setattr(model.net, 'features', feat_fn)
        setattr(model.net, 'classifier', class_fn)
        setattr(model.net, 'set_params', set_par_fn)

    if args.checkpoint_path is not None:
        model_stash = load_backup(model, args)
    else:
        model_stash = create_stash(model, args, dataset)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
    if args.tensorboard:
        tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash)
        model_stash['tensorboard_name'] = tb_logger.get_name()

    last_wish.register_action(lambda: save_backup(model, model_stash))

    print(file=sys.stderr)
    train_loader, test_loader = dataset.get_joint_loaders()

    for epoch in range(args.n_epochs):
        if epoch < model_stash['epoch_idx']:
            print('skipping epoch', epoch, file=sys.stderr)
            continue
        for i, data in enumerate(train_loader):
            if epoch < model_stash['epoch_idx'] and i < model_stash['batch_idx']:
                print('skipping batch', i, file=sys.stderr)
                continue

            inputs, labels, not_aug_inputs = data
            inputs, labels = inputs.to(model.device), labels.to(model.device)
            not_aug_inputs = not_aug_inputs.to(model.device)
            loss = model.observe(inputs, labels, not_aug_inputs)
            progress_bar(i, len(train_loader), epoch, 0, loss)

            if args.tensorboard:
                tb_logger.log_loss(loss, args, epoch, 0, i)

            model_stash['batch_idx'] = i + 1
        model_stash['epoch_idx'] = epoch + 1
        model_stash['batch_idx'] = 0

        if (epoch and epoch % 50 == 0) or epoch == args.n_epochs - 1:
            accs = evaluate(model, test_loader)
            print('\nAccuracy after {} epochs: {}'.format(epoch + 1, accs[0]))
            model_stash['mean_accs'].append(accs[0])
            if args.csv_log:
                csv_logger.log(np.array(accs))
            if args.tensorboard:
                tb_logger.log_accuracy(np.array(accs), np.array(accs), args, 0)

    if hasattr(model, 'end_task'):
        model.end_task(dataset)

    if args.tensorboard:
        tb_logger.close()
    if args.csv_log:
        csv_logger.write(vars(args))
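
The stash-driven resume logic reduces to a small pattern; a standalone sketch with a plain dict in place of the stash created by create_stash/load_backup:

stash = {'epoch_idx': 2, 'batch_idx': 0}   # as if two epochs already completed

for epoch in range(5):
    if epoch < stash['epoch_idx']:
        print('skipping epoch', epoch)
        continue
    # ... run the epoch's batches here ...
    stash['epoch_idx'] = epoch + 1
    stash['batch_idx'] = 0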
Example #8
def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    last_wish = LastWish()
    model.net.to(model.device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        get_par_fn = model.net.get_params
        feat_fn = model.net.features
        class_fn = model.net.classifier
        set_par_fn = model.net.set_params
        model.net = nn.DataParallel(model.net)
        setattr(model.net, 'get_params', get_par_fn)
        setattr(model.net, 'features', feat_fn)
        setattr(model.net, 'classifier', class_fn)
        setattr(model.net, 'set_params', set_par_fn)

    if args.checkpoint_path is not None:
        model_stash = load_backup(model, args)
    else:
        model_stash = create_stash(model, args, dataset)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
    if args.tensorboard:
        tb_logger = TensorboardLogger(args, dataset.SETTING, model_stash)
        model_stash['tensorboard_name'] = tb_logger.get_name()
    
    last_wish.register_action(lambda: save_backup(model, model_stash))

    print(file=sys.stderr)

    tasks = dataset.N_TASKS

    for t in range(tasks):
        model.net.train()
        train_loader, test_loader = dataset.get_data_loaders()
        if hasattr(model, 'begin_task'):
            model.begin_task(dataset)
        if t < model_stash['task_idx']:
            print('skipping task', t, file=sys.stderr)
            if t > 0 and args.csv_log:
                csv_logger.log(model_stash['mean_accs'][t - 1])
            if hasattr(model, 'end_task'):
                model.end_task(dataset)
            continue

        n_epochs = args.n_epochs
        if args.model == 'joint':
            n_epochs = 0

        for epoch in range(n_epochs):
            if t <= model_stash['task_idx'] and epoch < model_stash['epoch_idx']:
                print('skipping epoch', epoch, file=sys.stderr)
                continue

            for i, data in enumerate(train_loader):

                if t <= model_stash['task_idx'] and epoch < model_stash[
                        'epoch_idx'] and i < model_stash['batch_idx']:
                    print('skipping batch', i, file=sys.stderr)
                    continue

                inputs, labels, not_aug_inputs = data
                inputs, labels = inputs.to(model.device), labels.to(model.device)
                not_aug_inputs = not_aug_inputs.to(model.device)
                loss = model.observe(inputs, labels, not_aug_inputs)

                progress_bar(i, len(train_loader), epoch, t, loss)
                if args.tensorboard:
                    tb_logger.log_loss(loss, args, epoch, t, i)

                model_stash['batch_idx'] = i + 1
            model_stash['epoch_idx'] = epoch + 1
            model_stash['batch_idx'] = 0
            model.net.train()
        model_stash['task_idx'] = t + 1
        model_stash['epoch_idx'] = 0

        if hasattr(model, 'end_task'):
            model.end_task(dataset)

        accs = evaluate(model, dataset)
        mean_acc = np.mean(accs, axis=1)
        if n_epochs or t == dataset.N_TASKS - 1:
            print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)

        model_stash['mean_accs'].append(mean_acc)
        if args.csv_log:
            csv_logger.log(mean_acc)
        if args.tensorboard:
            tb_logger.log_accuracy(accs, mean_acc, args, t)

    if args.tensorboard:
        tb_logger.close()
    if args.csv_log:
        csv_logger.write(vars(args))
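
The DataParallel wrapping in Examples #7 and #8 re-attaches the custom methods because nn.DataParallel only proxies forward(), not arbitrary attributes of the wrapped module. A sketch of the same pattern (a no-op on a single-GPU or CPU machine; get_params here is a toy stand-in):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
# a custom method that nn.DataParallel would not expose on the wrapper
net.get_params = lambda: torch.cat([p.view(-1) for p in net.parameters()])

if torch.cuda.device_count() > 1:
    get_par_fn = net.get_params
    net = nn.DataParallel(net)
    setattr(net, 'get_params', get_par_fn)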