Example No. 1
    def multiply_learning_rate(self, factor):
        """used for adaptive learning rate adjustment"""
        new_lr = self.lr * factor

        for param_group in self.optimizer.param_groups:
            if new_lr != param_group['lr'] or Config.use_adaptive_learning_rate_adjustment:
                # the adaptive path scales each group's current lr; otherwise the base lr is rescaled
                target_lr = (param_group['lr'] * factor
                             if Config.use_adaptive_learning_rate_adjustment
                             else self.lr * factor)
                updated_lr = "updating lr from " + repr(param_group['lr']) + " to " + repr(target_lr)
                print(updated_lr)
                with open(DataManager.get_results_file(), 'a+') as f:
                    f.write(updated_lr)
                    f.write('\n')

                if Config.use_adaptive_learning_rate_adjustment:
                    param_group['lr'] *= factor
                else:
                    param_group['lr'] = new_lr
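
The same idea in isolation: a minimal, self-contained sketch (not this project's API) of a multiplicative learning-rate update over a plain torch.optim optimizer's param_groups. The toy model and the multiply_lr helper are hypothetical stand-ins for the method above.

import torch

model = torch.nn.Linear(4, 2)                              # toy model, hypothetical
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def multiply_lr(optimizer, factor):
    # scale every param group's current learning rate by `factor`
    for group in optimizer.param_groups:
        group['lr'] *= factor

multiply_lr(optimizer, 0.5)
print(optimizer.param_groups[0]['lr'])                     # -> 0.05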
Example No. 2
def train_epoch(model, train_loader, epoch, test_loader, device, augmentors=None, print_accuracy=False,
                drop_adaptive_learning_rate=False):
    """
    Run a single train epoch

    :param model: the network of type torch.nn.Module
    :param train_loader: the training dataset
    :param epoch: the current epoch
    :param test_loader: The test dataset loader
    :param print_accuracy: True if should test when printing batch info
    """
    model.train()
    loss = 0
    batch_idx = 0
    loops = 1 if not augmentors else 1 + len(augmentors)
    loops = 1 if Config.batch_by_batch else loops

    if Config.drop_learning_rate:
        if not Config.use_adaptive_learning_rate_adjustment:
            learning_rate_coefficient = 1 / pow(Config.drop_factor, math.floor(epoch / Config.drop_period))
            model.multiply_learning_rate(learning_rate_coefficient)
        elif drop_adaptive_learning_rate:
            model.multiply_learning_rate(1 / Config.drop_factor)

    for i in range(loops):
        if i == 0 and not Config.train_on_origonal_data and not Config.batch_by_batch:
            # skip the original-data pass in epoch splicing
            continue

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # batch loop
            model.optimizer.zero_grad()

            has_augs = augmentors is not None and len(augmentors) > 0

            train_on_aug = has_augs and i >= 1
            train_on_aug = train_on_aug or (Config.batch_by_batch and has_augs)
            if train_on_aug:
                if not Config.batch_by_batch:
                    augmentor = augmentors[i - 1]
                    if augmentor is None:
                        continue
                    loss += train_batch(model, inputs, targets, device, augmentor=augmentor)
                else:
                    for augmentor in augmentors:
                        loss += train_batch(model, inputs, targets, device, augmentor=augmentor)
                        model.optimizer.zero_grad()

            train_on_original = i == 0
            train_on_original = train_on_original or Config.batch_by_batch

            if train_on_original:
                loss += train_batch(model, inputs, targets, device)

    if print_epoch_every != -1 and epoch % print_epoch_every == 0:
        if print_accuracy:
            test_acc = test(model, test_loader, device, print_acc=False)
            print("epoch", epoch, "average loss:", loss / batch_idx, "accuracy:", test_acc, "i = ", i)
            with open(DataManager.get_results_file(), 'a+') as f:
                f.write(repr(epoch) + ': ' + repr(test_acc) + ', loss: ' + repr(loss / batch_idx))
                f.write('\n')

            model.train()
            return test_acc

        else:
            print("epoch", epoch, "average loss:", loss / batch_idx, "i=", i)

    end_time = time.time()
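
For reference, a quick sketch of the step-decay coefficient computed at the top of train_epoch, with stand-in values for Config.drop_factor and Config.drop_period (the real values live in the project's Config). Because the non-adaptive path of multiply_learning_rate rescales from the base learning rate each epoch, the effective rate simply drops by a factor of drop_factor every drop_period epochs.

import math

base_lr = 0.1     # stand-in for the model's base learning rate
drop_factor = 2   # stand-in for Config.drop_factor
drop_period = 10  # stand-in for Config.drop_period

for epoch in (0, 5, 10, 20, 30):
    coefficient = 1 / pow(drop_factor, math.floor(epoch / drop_period))
    print(epoch, base_lr * coefficient)   # 0.1, 0.1, 0.05, 0.025, 0.0125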