Esempio n. 1
0
    def train(self):
        """Run the train/validation loop for ``self.num_epoch`` epochs.

        Optionally resumes model and optimizer state, logs running and
        per-epoch losses, histograms and sample figures to TensorBoard,
        and checkpoints the model/optimizer after each training phase
        when ``self.save`` is set.

        Returns:
            dict: ``{"val_loss_avg": <float>}`` — the average epoch loss of
            the last phase processed in the final epoch (consumed by tune
            hyperparameter optimization).
        """
        if self.resume:
            print('Resuming training ...')
            # NOTE(review): resume expects one combined checkpoint named
            # 'torch_model', while the save branch below writes separate
            # 'model_state_dict' / 'optimizer_state_dict' files — confirm
            # which checkpoint layout the rest of the project expects.
            checkpoint = torch.load(os.path.join(self.log_root, 'torch_model'))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print('Starting training ...')

        writer = SummaryWriter(self.log_root)
        self.model = self.model.to(self.device)

        # Epoch/iteration counters are persisted on the model itself so a
        # resumed run continues its numbering.
        epoch = int(self.model.epoch) + 1
        it = int(self.model.iteration)
        for epoch in range(epoch, epoch + self.num_epoch):

            # One sub-directory per epoch for figures / model snapshots.
            epoch_root = 'epoch_{:02d}'.format(epoch)
            if not os.path.exists(os.path.join(self.log_root, epoch_root)):
                os.makedirs(os.path.join(self.log_root, epoch_root))

            for phase in self.data_loaders.keys():
                epoch_loss = 0

                # Toggle train mode (dropout / batch-norm updates) per phase.
                self.model.train(phase == 'train')

                running_loss = 0.0
                # assumes each loader yields (data, index) with data a dict
                # holding 'input' and 'target' tensors — TODO confirm.
                for i, (data, index) in enumerate(self.data_loaders[phase]):
                    it += 1
                    # copy input and targets to the device object
                    inputs = data['input'].to(self.device)
                    targets = data['target'].to(self.device)
                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                    # forward + backward + optimize
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)

                    # Only back-propagate and step during the training phase.
                    if phase == 'train':
                        loss.backward()
                        self.optimizer.step()

                    # print statistics
                    running_loss += loss.item()
                    epoch_loss += loss.item()
                    if (i + 1) % self.log_int == 0:
                        running_loss_avg = running_loss / self.log_int
                        print('Phase: ' + phase +
                              ', epoch: {}, batch {}: running loss: {:0.3f}'.
                              format(self.model.epoch, i +
                                     1, running_loss_avg))
                        writer.add_scalars('running_loss',
                                           {phase: running_loss_avg}, it)
                        running_loss = 0.0

                if phase in ['train', 'val']:
                    # NOTE(review): inputs/outputs below refer to the last
                    # batch of the loop above — assumes every loader yields
                    # at least one batch.
                    epoch_loss_avg = epoch_loss / self.data_lengths[phase]
                    print('Phase: ' + phase +
                          ', epoch: {}: epoch loss: {:0.3f}'.format(
                              epoch, epoch_loss_avg))
                    writer.add_scalars('epoch_loss', {phase: epoch_loss_avg},
                                       epoch)
                    writer.add_histogram(
                        'input histogram',
                        inputs.cpu().data.numpy()[0, 0].flatten(), epoch)
                    writer.add_histogram(
                        'output histogram',
                        outputs.cpu().data.numpy()[0, 0].flatten(), epoch)
                    # Plot at most the first four samples of the last batch.
                    figure_inds = list(range(min(inputs.shape[0], 4)))
                    fig = Trainer.show_imgs(inputs, outputs, figure_inds)
                    fig.savefig(
                        os.path.join(self.log_root, epoch_root,
                                     phase + '.png'))
                    writer.add_figure('images ' + phase, fig, epoch)

                # Use logical 'and', not bitwise '&': '&' happens to work on
                # bools but always evaluates both operands and is not the
                # idiomatic boolean conjunction.
                if self.save and phase == 'train':

                    print('Writing model graph...')
                    writer.add_graph(self.model, inputs)

                    print('Saving model state...')
                    # Store the counters as non-trainable parameters so they
                    # travel with the state_dict.
                    self.model.epoch = torch.nn.Parameter(torch.tensor(epoch),
                                                          requires_grad=False)
                    self.model.iteration = torch.nn.Parameter(
                        torch.tensor(it), requires_grad=False)
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                    },
                               os.path.join(self.log_root, epoch_root,
                                            'model_state_dict'))
                    torch.save(
                        {'optimizer_state_dict': self.optimizer.state_dict()},
                        os.path.join(self.log_root, 'optimizer_state_dict'))

        print('Finished training ...')
        writer.close()
        print('Writer closed ...')

        # dictionary of accuracy metrics for tune hyperparameter optimization
        return {"val_loss_avg": epoch_loss_avg}
Esempio n. 2
0
        loss = crossentropy(y, labels)
        vloss += loss.item()
        _, predicted = torch.max(y.data, 1)
        vcorrect += (predicted == labels).sum().item()
        vcount += BATCHSIZE
    return vloss/len(dataloader), 100.0*(1.0-vcorrect/vcount)

# Training
# Running accumulators for loss / accuracy statistics.
running_loss = 0.0
running_correct = 0
running_count = 0

# Add the graph to tensorboard using one sample batch.
dataiter = iter(train_loader)
# Use the builtin next(): the iterator's .next() method is Python-2 style
# and was removed from modern PyTorch DataLoader iterators.
data, labels = next(dataiter)
writer.add_graph(model, data)
writer.flush()

# Cycle through epochs
for epoch in range(100):
    
    # Cycle through batches
    for batch_idx, (data, labels) in enumerate(train_loader):

        # Reset gradients accumulated from the previous batch.
        optimizer.zero_grad()
        # Forward pass and loss (crossentropy is defined elsewhere in the file).
        y = model(data)
        loss = crossentropy(y, labels)
        # Backward pass; accumulate the scalar loss for logging, then step.
        # NOTE(review): the loop body may continue past this excerpt
        # (e.g. periodic logging of running_loss) — confirm against the
        # full file before restructuring.
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
Esempio n. 3
0
def main(args):
    """Build, configure and train a classification model end-to-end.

    Selects the network from ``args.model``, optionally loads finetuning
    weights, sets up TensorBoard logging, the SGD optimizer, checkpoint
    resuming, the loss function, data loaders and the LR scheduler, then
    runs the epoch loop with per-epoch validation and checkpointing.

    Args:
        args: parsed argument namespace (model, dataset, paths, hyper-
            parameters); attribute accesses below document the fields used.
    """
    # -----------------------------------------------------------------------------
    # Create model
    # -----------------------------------------------------------------------------
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        model = net.CNNModel(args)
    elif args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        model = net.EESPNet(args)
    elif args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        model = net.CNNModel(args)
    else:
        print_error_message('Model {} not yet implemented'.format(args.model))
        exit()

    if args.finetune:
        # load the weights for finetuning
        if os.path.isfile(args.weights_ft):
            pretrained_dict = torch.load(args.weights_ft,
                                         map_location=torch.device('cpu'))
            print_info_message('Loading pretrained basenet model weights')
            model_dict = model.state_dict()

            # Keep the *pretrained* tensors whose key exists in the current
            # model with a matching shape. (The original took the values
            # from model_dict, which made the update below a no-op: the
            # finetuning weights were never actually loaded.)
            overlap_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and model_dict[k].size() == v.size()
            }

            if len(overlap_dict) == 0:
                print_error_message(
                    'No overlaping weights between model file and pretrained weight file. Please check'
                )

            total_size_overlap = sum(v.numel() for v in overlap_dict.values())
            total_size_pretrain = sum(v.numel()
                                      for v in pretrained_dict.values())

            # max(..., 1) guards against ZeroDivisionError on an empty
            # pretrained state dict.
            print_info_message('Overlap ratio of weights: {:.2f} %'.format(
                (total_size_overlap * 100.0) / max(total_size_pretrain, 1)))

            model_dict.update(overlap_dict)
            model.load_state_dict(model_dict, strict=False)
            print_info_message('Pretrained basenet model loaded!!')
        else:
            print_error_message('Unable to find the weights: {}'.format(
                args.weights_ft))

    # -----------------------------------------------------------------------------
    # Writer for logging
    # -----------------------------------------------------------------------------
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.randn(1, 3, args.inpSize,
                                                    args.inpSize))
    except Exception:
        # Graph export is best-effort; narrow the catch so genuine problems
        # (KeyboardInterrupt/SystemExit) are not swallowed.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # network properties
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, args.inpSize, args.inpSize))
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    best_acc = 0.0
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            # torch.load above already mapped the tensors to CPU;
            # load_state_dict() takes no map_location argument (passing one
            # raised a TypeError at runtime).
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    # -----------------------------------------------------------------------------
    # Loss Fn
    # -----------------------------------------------------------------------------
    if args.dataset == 'imagenet':
        criterion = nn.CrossEntropyLoss()
        acc_metric = 'Top-1'
    elif args.dataset == 'coco':
        # Multi-label classification -> BCE over logits, F1 metric.
        criterion = nn.BCEWithLogitsLoss()
        acc_metric = 'F1'
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # -----------------------------------------------------------------------------
    # Data Loaders
    # -----------------------------------------------------------------------------
    # Data loading code
    if args.dataset == 'imagenet':
        train_loader, val_loader = img_loader.data_loaders(args)
        # import the loaders too
        from utilities.train_eval_classification import train, validate
    elif args.dataset == 'coco':
        from data_loader.classification.coco import COCOClassification
        train_dataset = COCOClassification(root=args.data,
                                           split='train',
                                           year='2017',
                                           inp_size=args.inpSize,
                                           scale=args.scale,
                                           is_training=True)
        val_dataset = COCOClassification(root=args.data,
                                         split='val',
                                         year='2017',
                                         inp_size=args.inpSize,
                                         is_training=False)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # import the loaders too
        from utilities.train_eval_classification import train_multi as train
        from utilities.train_eval_classification import validate_multi as validate
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    # -----------------------------------------------------------------------------
    # LR schedulers
    # -----------------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_sizes = args.steps
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        from utilities.lr_scheduler import CyclicLR
        step_sizes = args.steps
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max)
    else:
        print_error_message('Scheduler ({}) not yet implemented'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # set up the epoch variable in case resuming training
    if args.start_epoch != 0:
        for epoch in range(args.start_epoch):
            lr_scheduler.step(epoch)

    # Persist the effective run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # -----------------------------------------------------------------------------
    # Training and Val Loop
    # -----------------------------------------------------------------------------

    extra_info_ckpt = args.model + '_' + str(args.s)
    for epoch in range(args.start_epoch, args.epochs):
        lr_log = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_log
        print_info_message("LR for epoch {} = {:.5f}".format(epoch, lr_log))
        train_acc, train_loss = train(data_loader=train_loader,
                                      model=model,
                                      criteria=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      device=device)
        # evaluate on validation set
        val_acc, val_loss = validate(data_loader=val_loader,
                                     model=model,
                                     criteria=criterion,
                                     device=device)

        # remember best prec@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        # Unwrap DataParallel so the checkpoint keys are device-agnostic.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': weights_dict,
                'best_prec1': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Classification/LR/learning_rate', lr_log, epoch)
        writer.add_scalar('Classification/Loss/Train', train_loss, epoch)
        writer.add_scalar('Classification/Loss/Val', val_loss, epoch)
        writer.add_scalar('Classification/{}/Train'.format(acc_metric),
                          train_acc, epoch)
        writer.add_scalar('Classification/{}/Val'.format(acc_metric), val_acc,
                          epoch)
        writer.add_scalar('Classification/Complexity/Top1_vs_flops', best_acc,
                          round(flops, 2))
        writer.add_scalar('Classification/Complexity/Top1_vs_params', best_acc,
                          round(num_params, 2))

    writer.close()
Esempio n. 4
0
class Train():
    """Training harness for attribute and face-recognition models.

    Builds the ResNet backbone plus the task head selected by
    ``config.attribute``, data loaders, optimizer and TensorBoard writer,
    and exposes :meth:`run` to execute the training loop with periodic
    evaluation and checkpointing.
    """

    def __init__(self, config):
        self.config = config

        # Map attribute name -> head class; the 'recognition' head comes
        # from the config so margin-based heads can be swapped in.
        ATTR_HEAD = {'race': RaceHead, 'gender': GenderHead,
                     'age': AgeHead, 'recognition': self.config.recognition_head}

        self.writer = SummaryWriter(config.log_path)

        # LMDB loader when the source is a single file, otherwise a
        # directory + file-list loader.
        if path.isfile(self.config.train_source):
            self.train_loader = LMDBDataLoader(self.config, self.config.train_source)
        else:
            self.train_loader = CustomDataLoader(self.config, self.config.train_source,
                                                 self.config.train_list)

        class_num = self.train_loader.class_num()
        print(len(self.train_loader.dataset))
        print(f'Classes: {class_num}')

        self.model = ResNet(self.config.depth, self.config.drop_ratio, self.config.net_mode)
        self.head = ATTR_HEAD[self.config.attribute](classnum=class_num)

        paras_only_bn, paras_wo_bn = separate_bn_param(self.model)

        # Log the model graph with a dummy 1x3x112x112 input.
        dummy_input = torch.zeros(1, 3, 112, 112)
        self.writer.add_graph(self.model, dummy_input)

        if torch.cuda.device_count() > 1:
            print(f"Model will use {torch.cuda.device_count()} GPUs!")
            self.model = DataParallel(self.model)
            self.head = DataParallel(self.head)

        self.model = self.model.to(self.config.device)
        self.head = self.head.to(self.config.device)

        # Inverse-frequency class weights for the imbalanced race/gender
        # attributes; stored on the config for downstream use.
        self.weights = None
        if self.config.attribute in ['race', 'gender']:
            _, self.weights = np.unique(self.train_loader.dataset.get_targets(), return_counts=True)
            self.weights = np.max(self.weights) / self.weights
            self.weights = torch.tensor(self.weights, dtype=torch.float, device=self.config.device)
            self.config.weights = self.weights
            print(self.weights)

        if self.config.val_source is not None:
            if self.config.attribute != 'recognition':
                if path.isfile(self.config.val_source):
                    self.val_loader = LMDBDataLoader(self.config, self.config.val_source, False)
                else:
                    self.val_loader = CustomDataLoader(self.config, self.config.val_source,
                                                       self.config.val_list, False)

            else:
                # Recognition is validated on verification-pair datasets.
                self.validation_list = []
                for val_name in config.val_list:
                    dataset, issame = get_val_pair(self.config.val_source, val_name)
                    self.validation_list.append([dataset, issame, val_name])

        # Weight decay on conv/fc and head params, none on batch-norm params.
        self.optimizer = optim.SGD([{'params': paras_wo_bn,
                                     'weight_decay': self.config.weight_decay},
                                    {'params': self.head.parameters(),
                                     'weight_decay': self.config.weight_decay},
                                    {'params': paras_only_bn}],
                                   lr=self.config.lr, momentum=self.config.momentum)

        if self.config.resume:
            print(f'Resuming training from {self.config.resume}')
            load_state(self.model, self.head, self.optimizer, self.config.resume, False)

        if self.config.pretrained:
            print(f'Loading pretrained weights from {self.config.pretrained}')
            load_state(self.model, self.head, None, self.config.pretrained, True)

        print(self.config)
        self.save_file(self.config, 'config.txt')

        print(self.optimizer)
        self.save_file(self.optimizer, 'optimizer.txt')

        # Log ~100 loss points per epoch; evaluate ~5 times per epoch.
        self.tensorboard_loss_every = max(len(self.train_loader) // 100, 1)
        self.evaluate_every = max(len(self.train_loader) // 5, 1)

        if self.config.lr_plateau:
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode=self.config.max_or_min, factor=0.1,
                                               patience=3, verbose=True, threshold=0.001, cooldown=1)
        if self.config.early_stop:
            self.early_stop = EarlyStop(mode=self.config.max_or_min)

    def run(self):
        """Execute the training loop with periodic evaluation/checkpointing."""
        self.model.train()
        self.head.train()
        running_loss = 0.
        step = 0
        val_acc = 0.
        val_loss = 0.

        best_step = 0
        # Start at +inf for 'min' metrics (e.g. MAE), -inf for 'max' metrics.
        best_acc = float('Inf')
        if self.config.max_or_min == 'max':
            best_acc *= -1

        for epoch in range(self.config.epochs):
            train_logger = TrainLogger(self.config.batch_size, self.config.frequency_log)

            # Manual LR step schedule, unless plateau scheduling is active.
            if epoch + 1 in self.config.reduce_lr and not self.config.lr_plateau:
                self.reduce_lr()

            for idx, data in enumerate(self.train_loader):
                imgs, labels = data
                imgs = imgs.to(self.config.device)
                labels = labels.to(self.config.device)

                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                # Recognition heads (margin losses) also take the labels.
                if self.config.attribute == 'recognition':
                    outputs = self.head(embeddings, labels)
                else:
                    outputs = self.head(embeddings)

                if self.weights is not None:
                    loss = self.config.loss(outputs, labels, weight=self.weights)
                else:
                    loss = self.config.loss(outputs, labels)

                loss.backward()
                running_loss += loss.item()

                self.optimizer.step()

                if step % self.tensorboard_loss_every == 0:
                    loss_board = running_loss / self.tensorboard_loss_every
                    self.writer.add_scalar('train_loss', loss_board, step)
                    running_loss = 0.

                if step % self.evaluate_every == 0 and step != 0:
                    if self.config.val_source is not None:
                        val_acc, val_loss = self.evaluate(step)
                        self.model.train()
                        self.head.train()
                        best_acc, best_step = self.save_model(val_acc, best_acc, step, best_step)
                        print(f'Best accuracy: {best_acc:.5f} at step {best_step}')
                    else:
                        # No validation data: checkpoint unconditionally.
                        save_state(self.model, self.head, self.optimizer, self.config, 0, step)

                train_logger(epoch, self.config.epochs, idx, len(self.train_loader), loss.item())
                step += 1

            if self.config.lr_plateau:
                self.scheduler.step(val_acc)

            if self.config.early_stop:
                self.early_stop(val_acc)
                if self.early_stop.stop:
                    print("Early stopping model...")
                    break

        # Final evaluation. save_model returns a (best_acc, best_step)
        # tuple — unpack both (the original bound the tuple to best_acc
        # alone, printing a tuple and a stale best_step).
        val_acc, val_loss = self.evaluate(step)
        best_acc, best_step = self.save_model(val_acc, best_acc, step, best_step)
        print(f'Best accuracy: {best_acc} at step {best_step}')

    def save_model(self, val_acc, best_acc, step, best_step):
        """Checkpoint the model when ``val_acc`` improves.

        Improvement direction follows ``config.max_or_min``. Returns the
        (possibly updated) ``(best_acc, best_step)`` pair.
        """
        if (self.config.max_or_min == 'max' and val_acc > best_acc) or \
           (self.config.max_or_min == 'min' and val_acc < best_acc):
            best_acc = val_acc
            best_step = step
            save_state(self.model, self.head, self.optimizer, self.config, val_acc, step)

        return best_acc, best_step

    def reduce_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10

        print(self.optimizer)

    def tensorboard_val(self, accuracy, step, loss=0, dataset=''):
        """Log validation accuracy (and loss, for attribute tasks)."""
        self.writer.add_scalar('{}val_acc'.format(dataset), accuracy, step)

        if self.config.attribute != 'recognition':
            self.writer.add_scalar('val_loss', loss, step)

    def evaluate(self, step):
        """Run validation for the configured task; return (val_acc, val_loss).

        Attribute tasks use the val loader; recognition averages verification
        accuracy over all validation pair sets.
        """
        if self.config.attribute != 'recognition':
            val_acc, val_loss = self.evaluate_attribute()
            self.tensorboard_val(val_acc, step, val_loss)

        elif self.config.attribute == 'recognition':
            val_loss = 0
            val_acc = 0
            print('Validating...')
            # NOTE(review): assumes validation_list is non-empty; otherwise
            # idx below would be unbound.
            for idx, validation in enumerate(self.validation_list):
                dataset, issame, val_name = validation
                acc, std = self.evaluate_recognition(dataset, issame)
                self.tensorboard_val(acc, step, dataset=f'{val_name}_')
                print(f'{val_name}: {acc:.5f}+-{std:.5f}')
                val_acc += acc

            val_acc /= (idx + 1)
            self.tensorboard_val(val_acc, step)
            print(f'Mean accuracy: {val_acc:.5f}')

        return val_acc, val_loss

    def evaluate_attribute(self):
        """Compute validation metric and loss for the attribute task.

        Returns (accuracy, loss); for 'age' the metric is mean absolute
        error on summed per-bin predictions, otherwise argmax accuracy.
        """
        self.model.eval()
        self.head.eval()

        y_true = torch.tensor([], dtype=self.config.output_type, device=self.config.device)
        all_outputs = torch.tensor([], device=self.config.device)

        with torch.no_grad():
            for imgs, labels in iter(self.val_loader):
                imgs = imgs.to(self.config.device)
                labels = labels.to(self.config.device)

                embeddings = self.model(imgs)
                outputs = self.head(embeddings)

                y_true = torch.cat((y_true, labels), 0)
                all_outputs = torch.cat((all_outputs, outputs), 0)

            if self.weights is not None:
                loss = round(self.config.loss(all_outputs, y_true, weight=self.weights).item(), 4)
            else:
                loss = round(self.config.loss(all_outputs, y_true).item(), 4)

        y_true = y_true.cpu().numpy()

        if self.config.attribute == 'age':
            y_pred = all_outputs.cpu().numpy()
            y_pred = np.round(y_pred, 0)
            y_pred = np.sum(y_pred, axis=1)
            y_true = np.sum(y_true, axis=1)
            accuracy = round(mean_absolute_error(y_true, y_pred), 4)
        else:
            _, y_pred = torch.max(all_outputs, 1)
            y_pred = y_pred.cpu().numpy()

            accuracy = round(np.sum(y_true == y_pred) / len(y_pred), 4)

        return accuracy, loss

    def evaluate_recognition(self, samples, issame, nrof_folds=10, tta=False):
        """Embed ``samples`` batch-wise and run verification.

        Returns the mean and std of fold accuracies, rounded to 5 decimals.
        """
        self.model.eval()
        embeddings = np.zeros([len(samples), self.config.embedding_size])

        with torch.no_grad():
            # range() already advances idx by batch_size each iteration;
            # the original's extra "idx += batch_size" in the body (and the
            # pre-loop "idx = 0") were dead code and have been removed.
            for idx in range(0, len(samples), self.config.batch_size):
                batch = torch.tensor(samples[idx:idx + self.config.batch_size])
                embeddings[idx:idx + self.config.batch_size] = self.model(batch.to(self.config.device)).cpu()

        tpr, fpr, accuracy, best_thresholds = verification.evaluate(embeddings, issame, nrof_folds)

        return round(accuracy.mean(), 5), round(accuracy.std(), 5)

    def save_file(self, string, file_name):
        """Write ``str(string)`` to ``file_name`` inside the work path."""
        # Context manager guarantees the handle is closed even on error.
        with open(path.join(self.config.work_path, file_name), "w") as file:
            file.write(str(string))
Esempio n. 5
0
class ConvNetRunner:
    """Training/evaluation harness for a ConvNet audio classifier.

    Builds run paths, network, optimiser, LR scheduler, log file and
    TensorBoard writer from a single parsed-arguments object.
    """

    def __init__(self, args):
        # Unique run name: user prefix + whole-second unix timestamp.
        self.run_name = args.run_name + '_' + str(time.time()).split('.')[0]
        self.current_run_basepath = args.network_metrics_basepath + '/' + self.run_name + '/'
        self.learning_rate = args.learning_rate
        self.epochs = args.epochs
        self.test_net = args.test_net      # flag: restore weights and evaluate
        self.train_net = args.train_net    # flag: train from scratch
        self.batch_size = args.batch_size
        self.num_classes = args.num_classes
        self.audio_basepath = args.audio_basepath
        self.train_data_file = args.train_data_file
        self.test_data_file = args.test_data_file
        self.data_read_path = args.data_save_path
        self.is_cuda_available = torch.cuda.is_available()
        self.display_interval = args.display_interval
        self.sampling_rate = args.sampling_rate
        self.sample_size_in_seconds = args.sample_size_in_seconds
        self.overlap = args.overlap

        self.network_metrics_basepath = args.network_metrics_basepath
        self.tensorboard_summary_path = self.current_run_basepath + args.tensorboard_summary_path
        self.network_save_path = self.current_run_basepath + args.network_save_path

        self.network_restore_path = args.network_restore_path

        self.device = torch.device("cuda" if self.is_cuda_available else "cpu")
        self.network_save_interval = args.network_save_interval
        self.normalise = args.normalise_while_training
        self.dropout = args.dropout
        self.threshold = args.threshold
        self.debug_filename = self.current_run_basepath + '/' + args.debug_filename

        paths = [self.network_save_path, self.tensorboard_summary_path]
        file_utils.create_dirs(paths)

        self.network = ConvNet().to(self.device)
        # Deferred: pos_weight is computed from the training label
        # distribution in data_reader(), and loss_function is built from it
        # in train().
        self.pos_weight = None
        self.loss_function = None
        self.learning_rate_decay = args.learning_rate_decay

        self.optimiser = optim.Adam(self.network.parameters(),
                                    lr=self.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimiser, gamma=self.learning_rate_decay)

        # Global feature min/max, accumulated over the training set in
        # data_reader() and used there for min-max normalisation.
        self._min, self._max = float('inf'), -float('inf')

        # NOTE(review): if neither train_net nor test_net is set,
        # self.log_file is never created and the prints below raise
        # AttributeError — confirm callers always set one of the flags.
        if self.train_net:
            self.network.train()
            self.log_file = open(
                self.network_save_path + '/' + self.run_name + '.log', 'w')
            # NOTE(review): json.dumps on a plain argparse.Namespace raises
            # TypeError — presumably `args` is a dict-like here; verify.
            self.log_file.write(json.dumps(args))
        if self.test_net:
            print('Loading Network')
            self.network.load_state_dict(
                torch.load(self.network_restore_path,
                           map_location=self.device))
            self.network.eval()
            # Reuse the log file of the checkpointed run (append mode).
            self.log_file = open(
                self.network_restore_path.replace('_40.pt', '.log'), 'a')
            print(
                '\n\n\n********************************************************',
                file=self.log_file)
            print('Testing Model - ', self.network_restore_path)
            print('Testing Model - ',
                  self.network_restore_path,
                  file=self.log_file)
            print('********************************************************',
                  file=self.log_file)

        self.writer = SummaryWriter(self.tensorboard_summary_path)
        print("Network config:\n", self.network)
        print("Network config:\n", self.network, file=self.log_file)

        self.batch_loss, self.batch_accuracy, self.uar = [], [], []

        print('Configs used:\n', json.dumps(args, indent=4))
        print('Configs used:\n',
              json.dumps(args, indent=4),
              file=self.log_file)

    def data_reader(self,
                    data_filepath,
                    label_filepath,
                    train,
                    should_batch=True,
                    shuffle=True,
                    infer=False):
        """Load .npy data/labels and prepare them for training or evaluation.

        For train=True: accumulates global min/max (used later for
        normalisation), oversamples the positive class with random
        time/frequency masking, shuffles, and sets self.pos_weight from the
        label distribution. For all non-infer calls: optionally min-max
        normalises using the previously accumulated train statistics, then
        returns either (batched_input, batched_labels) lists of
        batch_size-sized chunks or the flat (input_data, labels).

        NOTE(review): the infer=True branch is unimplemented (`pass`) and
        falls through to an implicit None return — confirm before use.
        NOTE(review): the `shuffle` parameter is accepted but never read;
        shuffling only happens on the train branch.
        """
        if infer:
            pass
        else:
            input_data, labels = read_npy(data_filepath), read_npy(
                label_filepath)
            if train:

                print('Original data size - before Augmentation')
                print('Original data size - before Augmentation',
                      file=self.log_file)
                print('Total data ', len(input_data))
                print('Event rate', sum(labels) / len(labels))
                print(np.array(input_data).shape, np.array(labels).shape)

                print('Total data ', len(input_data), file=self.log_file)
                print('Event rate',
                      sum(labels) / len(labels),
                      file=self.log_file)
                print(np.array(input_data).shape,
                      np.array(labels).shape,
                      file=self.log_file)

                # Track global feature range across the training set; these
                # instance attributes drive normalisation for every split.
                for x in input_data:
                    self._min = min(np.min(x), self._min)
                    self._max = max(np.max(x), self._max)

                print('Data Augmentation starts . . .')
                print('Data Augmentation starts . . .', file=self.log_file)
                # Oversample the positive class by 130% using SpecAugment-style
                # masking (randomly time- or frequency-masked copies).
                label_to_augment = 1
                amount_to_augment = 1.3
                ones_ids = [
                    idx for idx, x in enumerate(labels)
                    if x == label_to_augment
                ]
                # Sampling WITH replacement, so the same clip can be
                # augmented more than once.
                random_idxs = random.choices(
                    ones_ids, k=int(len(ones_ids) * amount_to_augment))
                # assumes read_npy returned a numpy array (list indexing
                # with a list would raise TypeError) — TODO confirm
                data_to_augment = input_data[random_idxs]
                augmented_data = []
                augmented_labels = []
                for x in data_to_augment:
                    x = librosaSpectro_to_torchTensor(x)
                    x = random.choice([time_mask, freq_mask])(x)[0].numpy()
                    augmented_data.append(x), augmented_labels.append(
                        label_to_augment)

                input_data = np.concatenate((input_data, augmented_data))
                labels = np.concatenate((labels, augmented_labels))

                print('Data Augmentation done . . .')
                print('Data Augmentation done . . .', file=self.log_file)

                # Shuffle samples and labels together so pairs stay aligned.
                data = [(x, y) for x, y in zip(input_data, labels)]
                random.shuffle(data)
                input_data, labels = np.array([x[0] for x in data
                                               ]), [x[1] for x in data]

                # Initialize pos_weight based on training data
                # (ratio of negatives to positives, consumed by
                # BCEWithLogitsLoss in train()).
                self.pos_weight = len([x for x in labels if x == 0]) / len(
                    [x for x in labels if x == 1])
                print('Pos weight for the train data - ', self.pos_weight)
                print('Pos weight for the train data - ',
                      self.pos_weight,
                      file=self.log_file)

            print('Total data ', len(input_data))
            print('Event rate', sum(labels) / len(labels))
            print(np.array(input_data).shape, np.array(labels).shape)

            print('Total data ', len(input_data), file=self.log_file)
            print('Event rate', sum(labels) / len(labels), file=self.log_file)
            print(np.array(input_data).shape,
                  np.array(labels).shape,
                  file=self.log_file)

            print('Min max values used for normalisation ', self._min,
                  self._max)
            print('Min max values used for normalisation ',
                  self._min,
                  self._max,
                  file=self.log_file)

            # Normalizing `input data` on train dataset's min and max values
            # (requires a train=True call to have run first).
            if self.normalise:
                input_data = (input_data - self._min) / (self._max - self._min)

            if should_batch:
                # Plain chunking; the final batch may be smaller than
                # batch_size.
                batched_input = [
                    input_data[pos:pos + self.batch_size]
                    for pos in range(0, len(input_data), self.batch_size)
                ]
                batched_labels = [
                    labels[pos:pos + self.batch_size]
                    for pos in range(0, len(labels), self.batch_size)
                ]
                return batched_input, batched_labels
            else:
                return input_data, labels

    def run_for_epoch(self, epoch, x, y, type):
        """Evaluate the network on pre-batched data and log all metrics.

        x/y are lists of batches (as produced by data_reader). Writes
        per-split metrics to stdout, the log file, TensorBoard, and a debug
        .npy file. `type` is the split name used for logging ('Dev'/'Test');
        it shadows the builtin but is kept for interface compatibility.
        """
        self.network.eval()
        # Freeze BatchNorm running statistics so evaluation batches do not
        # update them.
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        predictions = []
        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, self.test_batch_ua, self.test_batch_f1, self.test_batch_precision, self.test_batch_recall, audio_for_tensorboard_test = [], [], [], [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions = self.network(audio_data).squeeze(1)
                # Loss is computed on raw logits; sigmoid is applied only
                # afterwards for thresholded metrics (matches train()).
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                predictions.append(to_numpy(test_predictions))
                test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
                    test_predictions, label, self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(to_numpy(test_accuracy))
                self.test_batch_uar.append(test_uar)
                self.test_batch_f1.append(test_f1)
                self.test_batch_precision.append(test_precision)
                self.test_batch_recall.append(test_recall)

                # Accumulate raw confusion-matrix entries across batches.
                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        print(f'***** {type} Metrics ***** ')
        print(f'***** {type} Metrics ***** ', file=self.log_file)
        print(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}"
        )
        print(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}",
            file=self.log_file)

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        # Flatten per-batch lists before dumping to the debug file.
        y = [element for sublist in y for element in sublist]
        predictions = [
            element for sublist in predictions for element in sublist
        ]
        write_to_npy(filename=self.debug_filename,
                     predictions=predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)

    def train(self):
        """Train the network; after every epoch, evaluate on the dev and test
        splits and periodically checkpoint the weights.

        Side effects: console + log-file metric prints, TensorBoard
        summaries, and .pt checkpoints under network_save_path.
        """

        # For purposes of calculating normalized values, call this method with train data followed by test
        # (the train=True read populates self._min/_max and self.pos_weight
        # that the later reads and the loss construction depend on).
        train_data, train_labels = self.data_reader(
            self.data_read_path + 'train_challenge_with_d1_data.npy',
            self.data_read_path + 'train_challenge_with_d1_labels.npy',
            shuffle=True,
            train=True)
        dev_data, dev_labels = self.data_reader(
            self.data_read_path + 'dev_challenge_with_d1_data.npy',
            self.data_read_path + 'dev_challenge_with_d1_labels.npy',
            shuffle=False,
            train=False)
        test_data, test_labels = self.data_reader(
            self.data_read_path + 'test_challenge_data.npy',
            self.data_read_path + 'test_challenge_labels.npy',
            shuffle=False,
            train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        total_step = len(train_data)
        # NOTE(review): range(1, self.epochs) performs epochs-1 passes —
        # confirm the off-by-one is intended.
        for epoch in range(1, self.epochs):
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, self.batch_f1, self.batch_precision, self.batch_recall, audio_for_tensorboard_train = [], [], [], [], [], [], None
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                # Log the model graph once, on the first batch of each epoch.
                if i == 0:
                    self.writer.add_graph(self.network, audio_data)
                predictions = self.network(audio_data).squeeze(1)
                # BCEWithLogitsLoss expects raw logits; sigmoid is applied
                # afterwards only for thresholded metric computation.
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                loss.backward()
                self.optimiser.step()
                accuracy, uar, precision, recall, f1 = accuracy_fn(
                    predictions, label, self.threshold)
                self.batch_loss.append(to_numpy(loss))
                self.batch_accuracy.append(to_numpy(accuracy))
                self.batch_uar.append(uar)
                self.batch_f1.append(f1)
                self.batch_precision.append(precision)
                self.batch_recall.append(recall)

                if i % self.display_interval == 0:
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}"
                    )
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}",
                        file=self.log_file)

            # Decay learning rate
            # NOTE(review): passing epoch= to scheduler.step() is deprecated
            # in newer torch releases — confirm the pinned torch version.
            self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            print('***** Overall Train Metrics ***** ')
            print('***** Overall Train Metrics ***** ', file=self.log_file)
            print(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}"
            )
            print(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}",
                file=self.log_file)
            print('Learning rate ',
                  self.optimiser.state_dict()['param_groups'][0]['lr'])
            print('Learning rate ',
                  self.optimiser.state_dict()['param_groups'][0]['lr'],
                  file=self.log_file)

            # dev data
            self.run_for_epoch(epoch, dev_data, dev_labels, type='Dev')

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            # Periodic checkpointing of the raw state dict.
            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                print('Network successfully saved: ' + save_path)

    def test(self):
        """Evaluate the restored network on the held-out test set and print
        accuracy/UAR to stdout and the log file."""
        test_data, test_labels = self.data_reader(
            self.data_read_path + 'test_challenge_data.npy',
            self.data_read_path + 'test_challenge_labels.npy',
            shuffle=False,
            train=False)
        test_predictions = self.network(test_data).squeeze(1)
        test_predictions = nn.Sigmoid()(test_predictions)
        # BUG FIX: accuracy_fn returns five values at every other call site
        # (see run_for_epoch and train); unpacking into two raised
        # ValueError here. The extra metrics are unpacked and ignored.
        test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
            test_predictions, test_labels, self.threshold)
        print(f"Accuracy: {test_accuracy} | UAR: {test_uar}")
        print(f"Accuracy: {test_accuracy} | UAR: {test_uar}",
              file=self.log_file)

    def infer(self, data_file):
        """Run the network on unlabeled data and return sigmoid probabilities."""
        # BUG FIX: data_reader takes label_filepath as a required positional
        # parameter; the original call omitted it and raised TypeError.
        # None is safe because the infer=True branch never reads labels.
        test_data = self.data_reader(data_file,
                                     None,
                                     shuffle=False,
                                     train=False,
                                     infer=True)
        # NOTE(review): data_reader's infer branch is currently `pass`, so
        # test_data is None here — confirm it is implemented before use.
        test_predictions = self.network(test_data).squeeze(1)
        test_predictions = nn.Sigmoid()(test_predictions)
        return test_predictions
Esempio n. 6
0
class MyTensorBoard():
    """Convenience wrapper around a TensorBoard SummaryWriter for image
    classification runs: image grids, model graphs, scalars, prediction
    figures, embedding projections and per-class PR curves."""

    def __init__(self, net, LabelStr, EventDir):
        # BUG FIX: the original assigned the undefined name `labelStr`
        # (NameError); the constructor parameter is `LabelStr`.
        self.labelStr = LabelStr
        self.writer = SummaryWriter(EventDir + '/')
        self.net = net

    def matplotlib_imshow(self, img, one_channel = True):
        """Render a normalised tensor image with matplotlib."""
        if one_channel:
            img = img.mean(dim = 0)
        # Undo the [-1, 1] normalisation back to [0, 1] for display.
        img = img / 2 + 0.5
        npimg = img.numpy()
        if one_channel:
            plt.imshow(npimg, cmap = 'Greys')
        else:
            plt.imshow(np.transpose(npimg, (1, 2, 0)))

    def ImageVisualize(self, images, labels):
        """Log a grid of input images to TensorBoard.

        `labels` is accepted for interface compatibility but unused.
        """
        img_grid = torchvision.utils.make_grid(images)
        self.matplotlib_imshow(img_grid, one_channel = True)
        self.writer.add_image('Images', img_grid)
        self.writer.close()

    # Add Net structure to Tensorboard
    def NetVisualize(self, sampleInput):
        self.writer.add_graph(self.net, sampleInput)
        self.writer.close()

    def images_to_probs(self, images):
        '''
        Generates predictions and corresponding probabilities from a trained network
        and a list of images
        '''
        # BUG FIX: the original referenced the bare name `net`; use the
        # network stored on the instance.
        output = self.net(images)
        _, preds_tensor = torch.max(output, 1)
        preds = np.squeeze(preds_tensor.numpy())
        return preds, [F.softmax(el, dim = 0)[i].item() for i, el in zip(preds, output)]

    def plot_classes_preds(self, images, labels):
        '''
            Generates matplotlib Figure using a trained network, along with images and labels
            from a batch, that shows the network's top predictions along with its probability,
            alongside the actual label, coloring this information based on whether the predictions
            was correct or not. Uses the "Images_to_probs" function
        '''
        # BUG FIX: images_to_probs and matplotlib_imshow are methods; the
        # original called them as bare names (NameError).
        preds, probs = self.images_to_probs(images)
        fig = plt.figure(figsize = (12, 48))
        # Show the first four images of the batch with predicted vs. actual.
        for idx in np.arange(4):
            ax = fig.add_subplot(1, 4, idx + 1, xticks = [], yticks = [])
            self.matplotlib_imshow(images[idx], one_channel = True)
            ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(self.labelStr[preds[idx]],
                                                              probs[idx] * 100.0,
                                                              self.labelStr[labels[idx]]),
                                                              color = ("green" if preds[idx] == labels[idx].item() else "red"))
        return fig

    def ScalarVisualize(self, graphTitle, loss, currentStep):
        '''
            Log scalar values to plots, e.g. loss, acc, during training
        '''
        self.writer.add_scalar(graphTitle, loss, currentStep)
        self.writer.close()

    def PredVisualize(self, images, labels, currentStep):
        '''
            Log matplotlib figures of model's predictions to specified mini-batch
        '''
        self.writer.add_figure('predictions vs. actuals',
                               self.plot_classes_preds(images, labels),
                               global_step = currentStep)
        self.writer.close()

    def ProjVisualize(self, data, labels, use_rand_instance = True, num_rand = 100):
        '''
            Add projection visualization to Tensorboard
        '''
        assert len(data) == len(labels)
        if use_rand_instance:
            # Pick num_rand random instances (same permutation for both).
            perm = torch.randperm(len(data))
            images, labels = data[perm][:num_rand], labels[perm][:num_rand]
        else:
            images, labels = data, labels
        class_labels = [self.labelStr[lab] for lab in labels]
        # Flatten 28x28 images into feature vectors for the projector.
        features = images.view(-1, 28 * 28)
        self.writer.add_embedding(features,
                                  metadata = class_labels,
                                  label_img = images.unsqueeze(1))
        self.writer.close()

    def PRcurveVisualize(self, test_probs, test_preds):
        '''
            Plot the Precision - Recall curve in Tensorboard, per - class wise
        '''
        for class_index in range(len(self.labelStr)):
            tensorboard_preds = test_preds == class_index
            tensorboard_probs = test_probs[:, class_index]
            self.writer.add_pr_curve(self.labelStr[class_index],
                                     tensorboard_preds,
                                     tensorboard_probs,
                                     global_step = 0)
        self.writer.close()
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=.001, momentum=.9)
    writer = SummaryWriter('runs/fashion_mnist_experiment_1')

    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    img_grid = torchvision.utils.make_grid(images)

    matplotlib_imshow(img_grid, one_channel=True)
    writer.add_image('four_fashion_mnist_images', img_grid)
    writer.add_graph(net,images)

    images, labels = select_n_random(trainset.data, trainset.targets)

    class_labels = [classes[lab] for lab in labels]

    features = images.view(-1, 28 * 28)
    writer.add_embedding(features,
                        metadata=class_labels,
                        label_img=images.unsqueeze(1))
    
    running_loss = 0
    for epoch in range(1):
        for i, data in enumerate(trainloader,0):
            inputs, labels = data
Esempio n. 8
0
        return out_stage6


if __name__ == "__main__":
    import numpy as np
    import cv2
    import util

    # Second image to compare.
    test_image2 = 'aa.jpg'

    source_img = cv2.imread(test_image2)
    scale = 0.5786163522012578
    stride = 8
    padValue = 128
    resized = cv2.resize(source_img, (0, 0),
                         fx=scale,
                         fy=scale,
                         interpolation=cv2.INTER_CUBIC)
    padded, pad = util.padRightDownCorner(resized, stride, padValue)
    # HWC uint8 -> NCHW float32, roughly normalised to [-0.5, 0.5).
    blob = np.float32(padded[:, :, :, np.newaxis])
    blob = np.transpose(blob, (3, 2, 0, 1)) / 256 - 0.5
    blob = np.ascontiguousarray(blob)

    # Trace the body-pose model once and dump its graph to TensorBoard.
    data = torch.from_numpy(blob).float()
    net = bodypose_model()
    net.eval()
    writer = SummaryWriter('runs')
    writer.add_graph(net, data)
    writer.close()
Esempio n. 9
0
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # Train the model
    total_step = len(dataloader)
    for epoch in range(num_epochs):  # Loop over the dataset multiple times
        train_loss = 0
        for step, (seq, label) in enumerate(dataloader):
            # Forward pass
            seq = seq.clone().detach().view(-1, window_size,
                                            input_size).to(device)
            output = model(seq)
            loss = criterion(output, label.to(device))

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            writer.add_graph(model, seq)
        print('Epoch [{}/{}], train_loss: {:.4f}'.format(
            epoch + 1, num_epochs, train_loss / total_step))
        writer.add_scalar('train_loss', train_loss / total_step, epoch + 1)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    torch.save(model.state_dict(), model_dir + '/' + log + '.pt')
    writer.close()
    print('Finished Training')
Esempio n. 10
0
def main():
    """Parse CLI args, start an Orca context, log FashionMNIST samples and
    the model graph to TensorBoard, then train/evaluate with the chosen
    backend (bigdl, torch_distributed or spark)."""
    parser = argparse.ArgumentParser(description='PyTorch Tensorboard Example')
    parser.add_argument('--cluster_mode', type=str, default="local",
                        help='The cluster mode, such as local, yarn, spark-submit or k8s.')
    parser.add_argument('--backend', type=str, default="bigdl",
                        help='The backend of PyTorch Estimator; '
                             'bigdl, torch_distributed and spark are supported.')
    parser.add_argument('--batch_size', type=int, default=64, help='The training batch size')
    parser.add_argument('--epochs', type=int, default=2, help='The number of epochs to train for')
    args = parser.parse_args()

    if args.cluster_mode == "local":
        init_orca_context()
    elif args.cluster_mode == "yarn":
        init_orca_context(cluster_mode=args.cluster_mode, cores=4, num_nodes=2)
    elif args.cluster_mode == "spark-submit":
        init_orca_context(cluster_mode=args.cluster_mode)

    tensorboard_dir = "runs"
    writer = SummaryWriter(tensorboard_dir + '/fashion_mnist_experiment_1')
    # constant for classes
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

    # plot some random training images
    dataiter = iter(train_data_creator(config={}, batch_size=4))
    # BUG FIX: iterator.next() is not the Python-3 iterator protocol and was
    # removed from DataLoader iterators in torch >= 1.13; use built-in next().
    images, labels = next(dataiter)

    # create grid of images
    img_grid = torchvision.utils.make_grid(images)

    # show images
    matplotlib_imshow(img_grid, one_channel=True)

    # write to tensorboard
    writer.add_image('four_fashion_mnist_images', img_grid)

    # inspect the model using tensorboard
    writer.add_graph(model_creator(config={}), images)
    writer.close()

    # training loss vs. epochs
    criterion = nn.CrossEntropyLoss()
    batch_size = args.batch_size
    epochs = args.epochs
    if args.backend == "bigdl":
        train_loader = train_data_creator(config={}, batch_size=batch_size)
        test_loader = validation_data_creator(config={}, batch_size=batch_size)

        net = model_creator(config={})
        optimizer = optimizer_creator(model=net, config={"lr": 0.001})
        orca_estimator = Estimator.from_torch(model=net,
                                              optimizer=optimizer,
                                              loss=criterion,
                                              metrics=[Accuracy()],
                                              backend="bigdl")

        orca_estimator.set_tensorboard(tensorboard_dir, "bigdl")

        orca_estimator.fit(data=train_loader, epochs=epochs, validation_data=test_loader,
                           checkpoint_trigger=EveryEpoch())

        res = orca_estimator.evaluate(data=test_loader)
        print("Accuracy of the network on the test images: %s" % res)
    elif args.backend in ["torch_distributed", "spark"]:
        orca_estimator = Estimator.from_torch(model=model_creator,
                                              optimizer=optimizer_creator,
                                              loss=criterion,
                                              metrics=[Accuracy()],
                                              backend=args.backend)
        stats = orca_estimator.fit(train_data_creator, epochs=epochs, batch_size=batch_size)

        for stat in stats:
            # NOTE(review): the writer was closed above; SummaryWriter
            # re-opens its file writer on demand, so this still logs.
            writer.add_scalar("training_loss", stat['train_loss'], stat['epoch'])
        print("Train stats: {}".format(stats))
        val_stats = orca_estimator.evaluate(validation_data_creator, batch_size=batch_size)
        print("Validation stats: {}".format(val_stats))
        orca_estimator.shutdown()
    else:
        # BUG FIX: the message omitted the supported "spark" backend handled
        # by the branch above.
        raise NotImplementedError("Only bigdl, torch_distributed and spark are supported "
                                  "as the backend, but got {}".format(args.backend))

    stop_orca_context()
def train_softmax(dataset_dir,
                  weights_dir=None,
                  run_name="run1",
                  epochs=80,
                  continue_epoch=None,
                  on_gpu=True,
                  checkpoint_dir="checkpoints",
                  batch_size=64,
                  print_interval=50,
                  num_features=3):
    """Train FeatureExtractorNet with a softmax/cross-entropy classifier head.

    Trains on a subset of a COCO-style train split, evaluates on the full
    validation split after every epoch, logs scalars/images to TensorBoard
    and writes one checkpoint per epoch.

    Args:
        dataset_dir: Root directory of the COCO-style dataset.
        weights_dir: Optional path to a state_dict to resume from.
        run_name: TensorBoard run name (logs go to ``runs/<run_name>``).
        epochs: Upper bound of the epoch range.
        continue_epoch: Epoch index to resume counting from, or None.
        on_gpu: If True, move model and batches to CUDA.
        checkpoint_dir: Directory where per-epoch checkpoints are written.
        batch_size: Batch size for both train and test loaders.
        print_interval: Log the running loss every N batches.
        num_features: Embedding size passed to FeatureExtractorNet.
    """
    writer = SummaryWriter(f"runs/{run_name}")

    if dataset_dir[-1] != '/':
        dataset_dir += '/'
    dataset = ClassificationDataset(dataset_dir,
                                    "annotations/instances_train2017.json",
                                    image_folder_path="train2017")
    dataset_length = len(dataset)
    val_dataset = ClassificationDataset(dataset_dir,
                                        "annotations/instances_val2017.json",
                                        image_folder_path="val2017")
    val_dataset_length = len(val_dataset)

    # NOTE(review): only 30% of the train split is actually trained on;
    # presumably intentional (shorter epochs) -- confirm.
    train_length = int(math.ceil(dataset_length * 0.3))
    test_length = dataset_length - train_length
    train_set, _ = torch.utils.data.random_split(dataset,
                                                 [train_length, test_length])

    dataloader = torch.utils.data.DataLoader(train_set,
                                             shuffle=True,
                                             batch_size=batch_size,
                                             num_workers=2)
    test_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                  shuffle=False,
                                                  batch_size=batch_size,
                                                  num_workers=2)

    num_classes = dataset.get_num_classes()
    print(f"Number of classes: {num_classes}")
    model = FeatureExtractorNet(use_classifier=True,
                                num_features=num_features,
                                num_classes=num_classes)

    # Freeze backbone children before index 15, train the rest.
    # NOTE(review): the original comment claimed "last 4 layers" -- verify
    # against the actual backbone architecture.
    for index, child in enumerate(model.backbone.children()):
        trainable = index >= 15
        for param in child.parameters():
            param.requires_grad = trainable

    # The classifier head is always trainable.
    for param in model.classifier.parameters():
        param.requires_grad = True

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if weights_dir:
        print(f"Continuing training using weights {weights_dir}")
        model.load_state_dict(torch.load(weights_dir))

    # One sample batch, used only to trace the model graph for TensorBoard.
    example_input, _ = next(iter(dataloader))
    writer.add_graph(model, example_input)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.001,
                                 weight_decay=0.01)  # lr 0.001 default

    print(
        f"Training with {train_length} train images, and {val_dataset_length} test images"
    )

    if on_gpu:
        model = model.cuda()

    running_loss = 0.0

    # here we start training
    got_examples = False
    start = 0 if continue_epoch is None else continue_epoch
    for epoch in range(start, epochs):
        model.train()
        for i, data in enumerate(tqdm(dataloader)):
            inputs, labels = data
            if not got_examples:
                # Log one un-normalized image grid so TensorBoard shows
                # recognizable pictures instead of normalized tensors.
                unnormalized = [
                    unnormalize_image(img,
                                      mean=(0.485, 0.456, 0.406),
                                      std=(0.229, 0.224, 0.225)).unsqueeze(0)
                    for img in inputs
                ]
                grid = torchvision.utils.make_grid(
                    torch.cat(unnormalized, dim=0))
                got_examples = True
                writer.add_image("images", grid, 0)

            if on_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % print_interval == print_interval - 1:
                loss = running_loss / print_interval
                print(f"[{epoch + 1}, {i + 1}] loss: {loss:.6f}")
                running_loss = 0.0
                writer.add_scalar("training loss", loss,
                                  epoch * len(dataloader) + i)

        # ---- evaluation on the validation split ----
        test_loss = 0
        test_correct = 0
        total_runs = 0

        model.eval()
        with torch.no_grad():
            for i, data in enumerate(test_dataloader, 0):
                inputs, labels = data
                if on_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = model(inputs)

                loss = criterion(outputs, labels)
                test_loss += loss.item()
                # BUG FIX: softmax must run over the class dimension (dim=1),
                # not the batch dimension (dim=0); with dim=0 each class gets
                # a different denominator, which can change the per-sample
                # argmax.
                softmax_output = F.softmax(outputs, dim=1)
                predicted_ids = softmax_output.cpu().data.numpy().argmax(1)
                labels_np = labels.cpu().data.numpy()
                test_correct += (labels_np == predicted_ids).sum()

                total_runs += 1

        avg_test_loss = test_loss / total_runs
        test_acc = (test_correct / val_dataset_length) * 100.0

        print(f"[{epoch + 1}] Test loss: {avg_test_loss:.5f}")
        print(f"[{epoch + 1}] Test acc.: {test_acc:.3f}%")

        writer.add_scalar("test loss", avg_test_loss, epoch + 1)
        writer.add_scalar("test accuracy", test_acc, epoch + 1)

        checkpoint_name = f"epoch-{epoch + 1}-loss-{avg_test_loss:.5f}-{test_acc:.2f}.pth"
        checkpoint_full_name = os.path.join(checkpoint_dir, checkpoint_name)
        print(f"[{epoch + 1}] Saving checkpoint as {checkpoint_full_name}")
        torch.save(model.state_dict(), checkpoint_full_name)

    print("Finished training")
    writer.close()
Esempio n. 12
0
def train_vi(model,
             optimizer,
             save_path,
             train_loader,
             test_loader,
             options,
             scheduler=None):
    """Train a variational inference model (VAE or CVAE).

    Runs an optional KL warm-up phase (optimizing only the likelihood term),
    then regular training on the full loss, periodically logging metrics and
    diagnostic figures to TensorBoard and checkpointing. Resumes from a
    checkpoint or a finished model when ``options['load_previous']`` is set.

    Args:
        model: A VAE or CVAE instance (dispatched on exact type).
        optimizer: Optimizer over ``model``'s parameters.
        save_path: Directory for logs, checkpoints and figures (created if
            missing).
        train_loader: DataLoader yielding dicts with key 'input' (and
            'conditional' for CVAE).
        test_loader: DataLoader with the same dict structure, used for
            evaluation plots and final metrics.
        options: Dict with keys such as 'epochs', 'warm_up_kl',
            'print_epochs', 'load_previous', 'hp_params'.
        scheduler: Optional LR scheduler stepped once per epoch.

    Raises:
        RuntimeError: If ``model`` is neither a VAE nor a CVAE.
    """
    log = logging.getLogger(LOGGER_NAME)
    log.info("Training VI")

    # Writer where to save the log files
    os.makedirs(save_path, exist_ok=True)
    torch.save(options, os.path.join(save_path, "optim_options.pth"))
    writer = SummaryWriter(save_path)

    # Don't use prob. functions to be able to trace back.
    # FIX: `iter(loader).next()` was removed from modern PyTorch/Python
    # iterators; use the builtin next().
    data = next(iter(train_loader))
    input_data = data['input'].to(DEVICE)
    # Exact-type checks are kept deliberately: CVAE may subclass VAE, in
    # which case isinstance() would mis-route a CVAE into the VAE branch.
    if type(model) == VAE:
        writer.add_graph(model, [input_data])
    elif type(model) == CVAE:
        cond_data = data['conditional'].to(DEVICE)
        writer.add_graph(model, [input_data, cond_data])
    else:
        raise RuntimeError("Unexpected type of model")
    writer.close()

    losses = list()
    mae_list = list()
    mse_list = list()
    checkpoint_filename = os.path.join(save_path, "checkpoint.pth")
    final_filename = os.path.join(save_path, "trained_model.pth")

    if os.path.isfile(final_filename) and options['load_previous']:
        # A finished model already exists: reload it and skip training.
        losses, mae_list, mse_list, epoch, time_optim = load_data(
            model, optimizer, final_filename, scheduler)
    else:
        if os.path.isfile(checkpoint_filename) and options['load_previous']:
            losses, mae_list, mse_list, epoch, time_optim = load_data(
                model, optimizer, checkpoint_filename, scheduler)
        else:
            epoch = 0
            time_optim = 0

        total_epochs = options.get('warm_up_kl', 0) + options.get('epochs', 0)
        for epoch in range(epoch, total_epochs):
            model.train()
            init_epoch_t = time.time()

            running_loss = 0
            running_mse = 0
            running_mae = 0
            num_batches = len(train_loader)
            for batch_idx, data in enumerate(train_loader):
                recon_data = data['input'].to(DEVICE)
                if type(model) == VAE:
                    input_data = [recon_data]
                elif type(model) == CVAE:
                    cond_data = data['conditional'].to(DEVICE)
                    input_data = [recon_data, cond_data]
                else:
                    raise RuntimeError(f"Unexpected model type: {type(model)}")

                rec_data = model.forward(*input_data)
                optimizer.zero_grad()
                loss = model.loss(rec_data, recon_data)
                # During KL warm-up only the likelihood term is optimized.
                if epoch < options.get('warm_up_kl', 0):
                    loss['ll'].backward()
                else:
                    loss['total'].backward()
                optimizer.step()

                # NOTE(review): rec_data[2] is assumed to be the
                # reconstruction -- confirm against model.forward()'s
                # return layout.
                running_mae += mae(rec_data[2], recon_data)
                running_mse += mse(rec_data[2], recon_data)
                running_loss += loss['total'].item()

            if scheduler is not None:
                scheduler.step()

            # Save the metrics (per-batch averages for this epoch)
            running_mae /= num_batches
            running_mse /= num_batches
            running_loss /= num_batches
            losses.append(running_loss)
            mae_list.append(running_mae)
            mse_list.append(running_mse)

            # Update time
            time_optim += (time.time() - init_epoch_t)

            # Print / Save info
            if (epoch + 1) % int(options['print_epochs']) == 0:
                # Metrics
                log.info(
                    f"\nEpoch: {epoch + 1}\tLoss: {running_loss:.6f}\tMSE: {running_mse:.6f}\tMAE: {running_mae:.6f}"
                )

                # Add it to the tensorboard
                for key in loss.keys():
                    writer.add_scalar(f'{key}', loss[f"{key}"], epoch)
                writer.add_scalar('MSE', running_mse, epoch)
                writer.add_scalar('MAE', running_mae, epoch)

                checkpoint = {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "optim_state": optimizer.state_dict(),
                    "losses": losses,
                    "metrics": {
                        'mae': mae_list,
                        'mse': mse_list
                    },
                    "training_info": model.save_training_info(),
                    "location": DEVICE,
                    "time": time_optim
                }
                if scheduler is not None:
                    checkpoint["scheduler_state"] = scheduler.state_dict()
                torch.save(checkpoint, checkpoint_filename)

                # Plot results with the test data and add them to the tensorboard
                with torch.no_grad():
                    model.eval()

                    data = next(iter(test_loader))
                    if type(model) == VAE:
                        input_data = [data['input'].to(DEVICE)]
                    elif type(model) == CVAE:
                        input_data = [
                            data['input'].to(DEVICE),
                            data['conditional'].to(DEVICE)
                        ]
                    else:
                        raise RuntimeError(
                            f"Unexpected model type: {type(model)}")

                    fig_lat = plot_latent(model, input_data, save_path=None)
                    fig_res = plot_residual(model, input_data, save_path=None)
                    fig_pred = plot_predictions(model,
                                                input_data,
                                                save_path=None)
                    fig_strip = plot_strip(model, input_data, save_path=None)

                    writer.add_figure('strip',
                                      fig_strip,
                                      global_step=epoch + 1)
                    writer.add_figure('latent', fig_lat, global_step=epoch + 1)
                    writer.add_figure('prediction',
                                      fig_pred,
                                      global_step=epoch + 1)
                    writer.add_figure('residual',
                                      fig_res,
                                      global_step=epoch + 1)
                    writer.close()

                    plt.close('all')

        # Save the final model
        checkpoint = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "losses": losses,
            "metrics": {
                'mae': mae_list,
                'mse': mse_list
            },
            "training_info": model.save_training_info(),
            "location": DEVICE,
            "time": time_optim
        }
        if scheduler is not None:
            checkpoint["scheduler_state"] = scheduler.state_dict()
        torch.save(checkpoint, final_filename)

    # Save the final results in the test model
    model.eval()
    data = next(iter(test_loader))
    recon_data_test = data['input'].to(DEVICE)
    if type(model) == VAE:
        input_data = [data['input'].to(DEVICE)]
    elif type(model) == CVAE:
        input_data = [data['input'].to(DEVICE), data['conditional'].to(DEVICE)]
    else:
        raise RuntimeError(f"Unexpected model type: {type(model)}")

    rec_data = model.forward(*input_data)
    test_mae = mae(rec_data[2], recon_data_test)
    test_mse = mse(rec_data[2], recon_data_test)

    writer.add_hparams(
        options['hp_params'], {
            'mse': mse_list[-1],
            'mae': mae_list[-1],
            'loss': losses[-1],
            'test_mse': test_mse,
            'test_mae': test_mae
        })

    # Try histogram
    for name, w in model.named_parameters():
        writer.add_histogram(name, w, epoch)

    # Residual's boxplot
    fig_lat = plot_latent(model,
                          input_data,
                          save_path=os.path.join(save_path, "latent.png"))
    fig_res = plot_residual(model,
                            input_data,
                            save_path=os.path.join(save_path, "residuals.png"))
    fig_pred = plot_predictions(model,
                                input_data,
                                save_path=os.path.join(save_path,
                                                       "prediction.png"))
    fig_strip = plot_strip(model,
                           input_data,
                           save_path=os.path.join(save_path, "strip_plot.png"))

    writer.add_figure('strip', fig_strip, global_step=epoch + 1)
    writer.add_figure('latent', fig_lat, global_step=epoch + 1)
    writer.add_figure('prediction', fig_pred, global_step=epoch + 1)
    writer.add_figure('residual', fig_res, global_step=epoch + 1)

    # Plot the loss
    fig, ax = plt.subplots(1, 1)
    ax.set_title("Loss")
    ax.plot(losses)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    fig.savefig(os.path.join(save_path, "loss.png"))

    log.info(f"\n====> OPTIMIZATION FINISHED <====\n")
    log.info(
        f"\nTraining: \tLoss: {losses[-1]:.6f}\tMSE: {mse_list[-1]:.6f}\tMAE: {mae_list[-1]:.6f}"
        f"\nTest: \tMSE: {test_mse:.6f}\tMAE: {test_mae:.6f}")
Esempio n. 13
0
        return x


net = Net(1, 10, 1)
print(net)

opti = torch.optim.SGD(net.parameters(), lr=0.5)  # optimizer
loss_func = nn.MSELoss()  # loss function

for epoch in range(100):
    prediction = net(x)
    loss = loss_func(prediction, y)

    opti.zero_grad()  # clear gradients left over from the previous step
    loss.backward()  # back-propagate to compute gradients of the loss
    opti.step()  # apply the gradients to update the parameters

    # log loss vs. epoch to TensorBoard
    writer.add_scalar('loss', loss.item(), epoch)

# save the model as a graph in TensorBoard
writer.add_graph(net, x)

writer.close()

# writer.add_image(tag, img)

# ====== launching TensorBoard =======
# cd into runs
# tensorboard --logdir=regression_graph
# open http://localhost:6006
Esempio n. 14
0
class runManager():
    """Tracks per-run and per-epoch training statistics and logs them to
    TensorBoard and an in-memory results table.

    Lifecycle: begin_run -> (begin_epoch -> track_* -> end_epoch)* -> end_run,
    then save() to dump the accumulated results as CSV and JSON.
    """

    def __init__(self):
        # Per-epoch bookkeeping
        self.epoch_count = 0  # number of epochs completed within the run
        self.epoch_loss = 0  # loss accumulated over the current epoch
        self.epoch_num_correct = 0  # correct predictions in the current epoch
        self.epoch_start_time = None  # wall-clock start of the current epoch

        # Per-run bookkeeping (one run per hyper-parameter combination)
        self.run_params = None  # hyper-parameter values for the current run
        self.run_count = 0  # index of the current run
        self.run_data = []  # one results row per epoch, across all runs
        self.run_start_time = None  # wall-clock start of the current run

        self.network = None  # network being trained
        self.loader = None  # data loader for the run
        self.tb = None  # TensorBoard writer

    def begin_run(self, run, network, loader):
        """Start a new run: record the hyper-params (a RunBuilder tuple),
        log a sample image grid and the model graph to TensorBoard."""
        self.run_start_time = time.time()
        self.run_params = run
        self.run_count += 1
        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')
        # Log one batch of images plus the traced model graph.
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network,
                          images.to(getattr(run, 'device', 'cpu')))

    def end_run(self):
        """Close the TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Reset the per-epoch accumulators and start the epoch timer."""
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Compute epoch metrics, write them to TensorBoard and append a
        results row to ``run_data``."""
        epoch_duration = time.time() - self.epoch_start_time
        # Elapsed time since the run started (accumulates across epochs).
        run_duration = time.time() - self.run_start_time
        loss = self.epoch_loss
        accuracy = self.epoch_num_correct / len(self.loader.dataset)
        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)
        # Histogram every parameter tensor for this epoch.
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
        # One results row per epoch: metrics plus the run's hyper-params.
        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results['loss'] = loss
        results["accuracy"] = accuracy
        results['epoch duration'] = epoch_duration
        results['run duration'] = run_duration
        for k, v in self.run_params._asdict().items():
            results[k] = v
        self.run_data.append(results)
        # BUG FIX: loss is a float; "%d" truncated it to an integer.
        print('runs: ' + "%d" % results["run"] + ', ' + 'epoch: ' +
              "%d" % results["epoch"] + ', ' + 'loss: ' +
              "%f" % results["loss"] + ', ' + 'accuracy: ' +
              "%f" % results["accuracy"])

    def track_loss(self, loss, batch):
        """Accumulate the batch loss weighted by the batch size
        (``batch[0].shape[0]``)."""
        self.epoch_loss += loss.item() * batch[0].shape[0]

    def track_num_correct(self, preds, labels):
        """Accumulate the number of correct predictions for the batch."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    def _get_num_correct(self, preds, labels):
        """Return how many rows of ``preds`` argmax to their label."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Dump the accumulated results table as <fileName>.csv and
        <fileName>.json."""
        pd.DataFrame.from_dict(self.run_data,
                               orient='columns').to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)
Esempio n. 15
0
print(f'[info] val   dataset has:{len_v} images.')

# Create tensorboardX logger
writer = SummaryWriter(resualt_save_dir)

# https://github.com/lanpa/tensorboardX

# Get some random training images, save model architecture and a dataset sample.
# FIX: `dataiter.next()` was removed from modern PyTorch/Python iterators;
# use the builtin next().
dataiter = iter(train_loader)
label_front, crop_front, label_top, meta_data = next(dataiter)

# create grid of images
img_grid = torchvision.utils.make_grid(crop_front)

writer.add_image('training_set_batches', img_grid)
writer.add_graph(model, (crop_front, label_front))
writer.close()

# Transfer model on the GPU/GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # FIX: device_ids was hard-coded to [0,1,2,3], which crashes on machines
    # with 2 or 3 GPUs; use however many are actually available.
    model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))


# FIX: corrected "Devie" typo in the log message.
print(f"[info] Device is:{device}")

#torch.cuda.set_device(0)
model = model.to(device)
def main(args: Dict[str, Any]):
    """Train a deep-metric-learning model end to end.

    Loads a YAML config, builds a ResNet-50 embedding model, selects the loss
    function and matching DataLoader/sampler from ``args["loss"]``, trains for
    ``args["n_epochs"]`` epochs with periodic validation, then logs
    embeddings, the model graph, and hyper-parameters to TensorBoard.

    Args:
        args: Parsed CLI arguments. Keys read here include "config",
            "use_gpu", "loss", "train_dir", "test_dir", "reference_dir",
            "n_workers", "n_epochs", "checkpoint_root_dir",
            "n_samples_per_reference_class", "log_frequency",
            "validate_frequency".

    Raises:
        Exception: If ``args["loss"]`` is not one of the supported losses.
    """
    start = time.time()

    # Initialize config
    config_path: str = args["config"]
    with open(config_path, "r", encoding="utf-8") as f:
        config: Dict[str, Any] = yaml.safe_load(f)
    logger.info(f"Loaded config at: {config_path}")
    logger.info(f"{pformat(config)}")


    # Initialize device
    if args["use_gpu"] and torch.cuda.is_available():
        device: torch.device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")


    # Initialize model
    model = nn.DataParallel(Resnet50(
        embedding_size=config["embedding_size"],
        pretrained=config["pretrained"]
    ))
    #checkpoint = torch.load("/home/janischl/deep-metric-learning-tsinghua-dogs/src/checkpoints/softtriple-resnet50/2021-04-11_21-31-31/epoch37-iter40000-map99.58.pth")
    #model.module.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    logger.info(f"Initialized model: {model}")


    # Initialize optimizer
    optimizer = RAdam(model.parameters(), lr=config["lr"])
    logger.info(f"Initialized optimizer: {optimizer}")


    # Initialize train transforms
    transform_train = T.Compose([
        T.Resize((config["image_size"], config["image_size"])),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.RandomAffine(degrees=5, scale=(0.8, 1.2), translate=(0.2, 0.2)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
    logger.info(f"Initialized training transforms: {transform_train}")


    # Initialize training set
    train_set = Dataset(args["train_dir"], transform=transform_train)

    # Each loss branch below also defines `batch_size`, which the test and
    # reference loaders reuse further down.
    if args["loss"] == "tripletloss":
        # Initialize train loader for triplet loss
        batch_size: int = config["classes_per_batch"] * config["samples_per_class"]
        train_loader = DataLoader(
            train_set,
            batch_size,
            sampler=PKSampler(
                train_set.targets,
                config["classes_per_batch"],
                config["samples_per_class"]
            ),
            shuffle=False,
            num_workers=args["n_workers"],
            pin_memory=True,
        )
        logger.info(f"Initialized train_loader: {train_loader.dataset}")

        # Initialize loss function
        loss_function = TripletMarginLoss(
            margin=config["margin"],
            sampling_type=config["sampling_type"]
        )
        logger.info(f"Initialized training loss: {loss_function}")

    elif args["loss"] == "proxy_nca":
        # Initialize train loader for proxy-nca loss
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=args["n_workers"],
            pin_memory=True,
        )
        logger.info(f"Initialized train_loader: {train_loader.dataset}")

        loss_function = ProxyNCALoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            embedding_scale=config["embedding_scale"],
            proxy_scale=config["proxy_scale"],
            smoothing_factor=config["smoothing_factor"],
            device=device
        )

    elif args["loss"] == "proxy_anchor":
        # Initialize train loader for proxy-anchor loss
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=args["n_workers"],
            pin_memory=True,
        )
        logger.info(f"Initialized train_loader: {train_loader.dataset}")

        loss_function = ProxyAnchorLoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            margin=config["margin"],
            alpha=config["alpha"],
            device=device
        )

    elif args["loss"] == "soft_triple":
        # Initialize train loader for soft-triple loss
        batch_size: int = config["batch_size"]
        train_loader = DataLoader(
            train_set,
            config["batch_size"],
            shuffle=True,
            num_workers=args["n_workers"],
            pin_memory=True,
        )
        logger.info(f"Initialized train_loader: {train_loader.dataset}")

        loss_function = SoftTripleLoss(
            n_classes=len(train_set.classes),
            embedding_size=config["embedding_size"],
            n_centers_per_class=config["n_centers_per_class"],
            lambda_=config["lambda"],
            gamma=config["gamma"],
            tau=config["tau"],
            margin=config["margin"],
            device=device
        )
    else:
        raise Exception("Only the following losses is supported: "
                        "['tripletloss', 'proxy_nca', 'proxy_anchor', 'soft_triple']. "
                        f"Got {args['loss']}")


    # Initialize test transforms
    transform_test = T.Compose([
        T.Resize((config["image_size"], config["image_size"])),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
    logger.info(f"Initialized test transforms: {transform_test}")


    # Initialize test set and test loader
    test_dataset = Dataset(args["test_dir"], transform=transform_test)
    test_loader = DataLoader(
        test_dataset, batch_size,
        shuffle=False,
        num_workers=args["n_workers"],
    )
    logger.info(f"Initialized test_loader: {test_loader.dataset}")


    # Initialize reference set and reference loader
    # If reference set is not given, use train set as reference set, but without random sampling
    if not args["reference_dir"]:
        reference_set = Dataset(args["train_dir"], transform=transform_test)
    else:
        reference_set = Dataset(args["reference_dir"], transform=transform_test)
    # Sometimes reference set is too large to fit into memory,
    # therefore we only sample a subset of it.
    n_samples_per_reference_class: int = args["n_samples_per_reference_class"]
    if n_samples_per_reference_class > 0:
        reference_set = get_subset_from_dataset(reference_set, n_samples_per_reference_class)

    reference_loader = DataLoader(
        reference_set, batch_size,
        shuffle=False,
        num_workers=args["n_workers"],
    )
    logger.info(f"Initialized reference set: {reference_loader.dataset}")


    # Initialize checkpointing directory
    checkpoint_dir: str = os.path.join(args["checkpoint_root_dir"], CURRENT_TIME)
    writer = SummaryWriter(log_dir=checkpoint_dir)
    logger.info(f"Created checkpoint directory at: {checkpoint_dir}")


    # Dictionary contains all metrics
    output_dict: Dict[str, Any] = {
        "total_epoch": args["n_epochs"],
        "current_epoch": 0,
        "current_iter": 0,
        "metrics": {
            "mean_average_precision": 0.0,
            "average_precision_at_1": 0.0,
            "average_precision_at_5": 0.0,
            "average_precision_at_10": 0.0,
            "top_1_accuracy": 0.0,
            "top_5_accuracy": 0.0,
            "normalized_mutual_information": 0.0,
        }
    }
    # Start training and testing
    logger.info("Start training...")
    for _ in range(1, args["n_epochs"] + 1):
        # output_dict is threaded through each epoch and accumulates
        # the current epoch/iteration counters and best metrics.
        output_dict = train_one_epoch(
            model, optimizer, loss_function,
            train_loader, test_loader, reference_loader,
            writer, device, config,
            checkpoint_dir,
            args['log_frequency'],
            args['validate_frequency'],
            output_dict
        )
    logger.info(f"DONE TRAINING {args['n_epochs']} epochs")


    # Visualize embeddings
    #logger.info("Calculating train embeddings for visualization...")
    log_embeddings_to_tensorboard(train_loader, model, device, writer, tag="train")
    # logger.info("Calculating reference embeddings for visualization...")
    log_embeddings_to_tensorboard(reference_loader, model, device, writer, tag="reference")
    # logger.info("Calculating test embeddings for visualization...")
    log_embeddings_to_tensorboard(test_loader, model, device, writer, tag="test")


    # Visualize model's graph
    logger.info("Adding graph for visualization")
    with torch.no_grad():
        dummy_input = torch.zeros(1, 3, config["image_size"], config["image_size"]).to(device)
        writer.add_graph(model.module.features, dummy_input)


    # Save all hyper-parameters and corresponding metrics
    # NOTE(review): add_hparams assumes every config value is a scalar
    # (int/float/str/bool) -- confirm the YAML config holds no nested values.
    logger.info("Saving all hyper-parameters")
    writer.add_hparams(
        config,
        metric_dict={f"hyperparams/{key}": value for key, value in output_dict["metrics"].items()}
    )
    with open(os.path.join(checkpoint_dir, "output_dict.json"), "w") as f:
        json.dump(output_dict, f, indent=4)
    logger.info(f"Dumped output_dict.json at {checkpoint_dir}")


    end = time.time()
    logger.info(f"EVERYTHING IS DONE. Training time: {round(end - start, 2)} seconds")
Esempio n. 17
0
    }
    agent = Agent(params)

    if not os.path.isdir(params['path_logs_dir']):
        os.makedirs(params['path_logs_dir'])
    shutil.copy('./params.json', params['path_logs_dir'] + '/params.json')
    writer = SummaryWriter(params['path_logs_dir'])
    dummy_input_to_policy_net = torch.randn(
        1, json_params['size_resized_image'],
        json_params['size_resized_image']).float().to(
            params['device']).unsqueeze(0)
    dummy_input_to_target_net = torch.randn(
        1, json_params['size_resized_image'],
        json_params['size_resized_image']).float().to(
            params['device']).unsqueeze(0)
    writer.add_graph(agent.brain.policy_net, dummy_input_to_policy_net)
    writer.add_graph(agent.brain.target_net, dummy_input_to_target_net)

    for episode in range(1, json_params['num_episodes'] + 1):
        observation = env.reset()
        state = preprocess(observation, json_params['size_resized_image'])
        t = done = total_rewards = total_loss = total_max_q_val = 0

        while True:
            if json_params['render']:
                env.render()

            t += 1
            action = agent.get_action(state)
            observation, reward, done, _ = env.step(action)
            if done:
def main():
    """Train a CNNGeometric registration model on a synthetic-pair dataset.

    Parses CLI arguments, resolves dataset paths (downloading PASCAL if
    needed), builds the model, loss, dataloaders and optimizer, then runs
    the epoch loop: train, validate, checkpoint the best model (lowest
    validation loss) and log losses to TensorBoard and a CSV file.
    """
    warnings.simplefilter(action='ignore', category=FutureWarning)

    args, arg_groups = ArgumentParser(mode='train').parse()
    print(args)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # Seed RNGs so runs are reproducible.
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Download dataset if needed and set paths. PASCAL is special-cased
    # because it may have to be downloaded first.
    if args.training_dataset == 'pascal':
        if args.dataset_image_path == '' and not os.path.exists('datasets/pascal-voc11/TrainVal'):
            download_pascal('datasets/pascal-voc11/')
        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/pascal-voc11/'
        args.dataset_csv_path = 'training_data/pascal-random'

    # The remaining datasets only need default paths filled in; a lookup
    # table replaces the original copy-pasted if-blocks.
    dataset_defaults = {
        'rgb512_aug': ('datasets/rgb512_augmented/',
                       'training_data/rgb512_augmented-random'),
        'rgb240_aug': ('datasets/rgb240_augmented/',
                       'training_data/rgb240_augmented-random'),
        'red240_aug': ('datasets/red240_augmented/',
                       'training_data/red240_augmented-random'),
        'smallset': ('datasets/smallset/',
                     'training_data/smallset-random'),
    }
    if args.training_dataset in dataset_defaults:
        default_image_path, default_csv_path = dataset_defaults[args.training_dataset]
        if args.dataset_image_path == '':
            args.dataset_image_path = default_image_path
        args.dataset_csv_path = default_csv_path

    # CNN model and loss: the regression head's output dimension depends on
    # the parameterization of the geometric transform.
    print('Creating CNN model...')
    if args.geometric_model == 'affine':
        cnn_output_dim = 6   # 2x3 affine matrix
    elif args.geometric_model == 'hom' and args.four_point_hom:
        cnn_output_dim = 8   # 4 corner points
    elif args.geometric_model == 'hom' and not args.four_point_hom:
        cnn_output_dim = 9   # full 3x3 homography
    elif args.geometric_model == 'tps':
        cnn_output_dim = 18  # TPS control points

    model = CNNGeometric(use_cuda=use_cuda,
                         output_dim=cnn_output_dim,
                         **arg_groups['model'])

    # Bias the regressor toward the identity transform so training starts
    # from a sensible initial warp.
    if args.geometric_model == 'hom' and not args.four_point_hom:
        init_theta = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta

    if args.geometric_model == 'hom' and args.four_point_hom:
        init_theta = torch.tensor([-1, -1, 1, 1, -1, 1, -1, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta

    if args.use_mse_loss:
        print('Using MSE loss...')
        loss = nn.MSELoss()
    else:
        print('Using grid loss...')
        loss = TransformedGridLoss(use_cuda=use_cuda,
                                   geometric_model=args.geometric_model)

    # Initialize Dataset objects
    dataset = SynthDataset(geometric_model=args.geometric_model,
                           dataset_csv_path=args.dataset_csv_path,
                           dataset_csv_file='train.csv',
                           dataset_image_path=args.dataset_image_path,
                           transform=NormalizeImageDict(['image']),
                           random_sample=args.random_sample)

    dataset_val = SynthDataset(geometric_model=args.geometric_model,
                               dataset_csv_path=args.dataset_csv_path,
                               dataset_csv_file='val.csv',
                               dataset_image_path=args.dataset_image_path,
                               transform=NormalizeImageDict(['image']),
                               random_sample=args.random_sample)

    # Set Tnf pair generation func
    pair_generation_tnf = SynthPairTnf(geometric_model=args.geometric_model,
                                       use_cuda=use_cuda)

    # Initialize DataLoaders
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=4)

    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=True, num_workers=4)

    # Optimizer and optional cosine-annealing scheduler
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.lr_max_iter, eta_min=5e-7)
    else:
        scheduler = False

    # Set up names for checkpoints
    if args.use_mse_loss:
        ckpt = args.training_dataset + '_' + args.trained_model_fn + '_' + args.geometric_model + '_mse_loss' + args.feature_extraction_cnn
    else:
        ckpt = args.trained_model_fn + '_' + args.geometric_model + args.feature_regression + args.feature_extraction_cnn + args.training_dataset + '_' + '_grid_loss'
    checkpoint_path = os.path.join(args.trained_model_dir,
                                   args.trained_model_fn,
                                   ckpt + '.pth.tar')
    # BUG FIX: checkpoints and the final CSV are written into the
    # <trained_model_dir>/<trained_model_fn>/ sub-directory, but only the
    # top-level directory used to be created (with os.mkdir), so saving
    # failed on a fresh run. Create the whole tree instead.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Set up TensorBoard writer
    if not args.log_dir:
        tb_dir = os.path.join(args.trained_model_dir,
                              args.trained_model_fn + '_tb_logs')
    else:
        tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs')

    logs_writer = SummaryWriter(tb_dir)
    # add graph, to do so we have to generate a dummy input to pass along with the graph
    dummy_input = {'source_image': torch.rand([args.batch_size, 3, 240, 240], device=device),
                   'target_image': torch.rand([args.batch_size, 3, 240, 240], device=device),
                   'theta_GT': torch.rand([16, 2, 3], device=device)}

    logs_writer.add_graph(model, dummy_input)

    # Start of training
    print('Starting training...')

    best_val_loss = float("inf")
    # Collect per-epoch rows and build the DataFrame once at the end.
    # BUG FIX: DataFrame.append was deprecated and removed in pandas 2.0.
    history = []
    for epoch in range(1, args.num_epochs + 1):

        train_loss = train(epoch, model, loss, optimizer,
                           dataloader, pair_generation_tnf,
                           log_interval=args.log_interval,
                           scheduler=scheduler,
                           tb_writer=logs_writer)

        val_loss = validate_model(model, loss,
                                  dataloader_val, pair_generation_tnf,
                                  epoch, logs_writer)

        # Logging losses to .csv so we can re-use them later for graphs
        history.append({'epoch': epoch,
                        'train_loss': train_loss,
                        'val_loss': val_loss})

        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint({
                         'epoch': epoch + 1,
                         'args': args,
                         'state_dict': model.state_dict(),
                         'best_val_loss': best_val_loss,
                         'optimizer': optimizer.state_dict(),
                         },
                        is_best, checkpoint_path)

    df = pd.DataFrame(history)
    name = args.geometric_model + '_' + args.feature_extraction_cnn + '_' + args.feature_regression + 'dropout_' + str(args.fr_dropout) + '.csv'
    csv_path = os.path.join(args.trained_model_dir,
                            args.trained_model_fn,
                            name)
    df.to_csv(csv_path)
    logs_writer.close()
    print('Done!')
Esempio n. 19
0
def main():
    """Train a (possibly distributed) super-resolution model.

    Workflow: parse CLI options and the YAML config, set up distributed
    state, loggers and (optionally) TensorBoard, build train/val loaders
    and the model, resume from a state file if given, then run the
    training loop with periodic logging, validation (border-cropped PSNR)
    and checkpointing.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # BUG FIX: tb_logger must be defined for every rank/configuration; it
    # was previously referenced unconditionally further down and raised a
    # NameError whenever TensorBoard logging was disabled, the run name
    # contained 'debug', or on non-zero ranks.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # cuDNN autotuner: pick the fastest kernels for fixed input sizes.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # Log one sample GT batch and the generator's graph before training.
    # BUG FIX: guarded on tb_logger (was unconditional) and removed a
    # leftover debug print of the raw image tensor.
    if tb_logger is not None:
        images = next(iter(train_loader))['GT']
        grid = torchvision.utils.make_grid(images)
        tb_logger.add_image('images', grid, 0)
        tb_logger.add_graph(model.netG.module, images.cuda())

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (writer exists only on rank <= 0
                    # with TensorBoard enabled and a non-debug run name)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation (rank 0 / single-process only)
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    # BUG FIX: only touch TensorBoard when a writer exists.
                    if tb_logger is not None:
                        images = val_data['GT']
                        grid = torchvision.utils.make_grid(images)
                        tb_logger.add_image('images', grid, 0)
                        tb_logger.add_graph(model.netG.module, images.cuda())

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    lr_img = util.tensor2img(visuals['LR'])

                    gtl_img = util.tensor2img(visuals['LR_ref'])

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # Save LR images
                    save_img_path_L = os.path.join(
                        img_dir,
                        '{:s}_forwLR_{:d}.png'.format(img_name, current_step))
                    util.save_img(lr_img, save_img_path_L)

                    # Save ground truth (only on the first validation pass)
                    if current_step == opt['train']['val_freq']:
                        save_img_path_gt = os.path.join(
                            img_dir,
                            '{:s}_GT_{:d}.png'.format(img_name, current_step))
                        util.save_img(gt_img, save_img_path_gt)
                        save_img_path_gtl = os.path.join(
                            img_dir, '{:s}_LR_ref_{:d}.png'.format(
                                img_name, current_step))
                        util.save_img(gtl_img, save_img_path_gtl)

                    # calculate PSNR on border-cropped images (crop by the
                    # SR scale factor to ignore boundary artifacts)
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}.'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if tb_logger is not None:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
Esempio n. 20
0
# do text parsing, get vocab size and class count.
# Build word/label vocabularies from the training file, then load them back
# as bidirectional lookup tables (token -> id and id -> token).
build_vocab(args.train, args.output_vocab_label, args.output_vocab_word)
label2id, id2label = load_vocab(args.output_vocab_label)
word2id, id2word = load_vocab(args.output_vocab_word)

vocab_size = len(word2id)
num_class = len(label2id)

# set model: an attention BiLSTM text classifier; every hyper-parameter
# comes straight from the parsed CLI arguments.
model = AttentionBiLSTM(vocab_size=vocab_size, num_class=num_class, emb_dim=args.embedding_dim, emb_droprate=args.embedding_droprate, sequence_len=args.sequence_len, att_droprate=args.att_droprate, rnn_cell_hidden=args.rnn_cell_hidden, num_layers=args.num_layers, att_method=args.att_method)
model.build()
model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
# Adam with a small weight_decay for L2 regularization.
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)
# Trace the model graph for TensorBoard using a random dummy token batch
# (token ids sampled in [0, 1000) — presumably within the vocabulary; verify).
writer.add_graph(model, torch.randint(low=0, high=1000, size=(args.batch_size, args.sequence_len), dtype=torch.long).to(device))
print(summary(model, torch.randint(low=0, high=1000, size=(args.batch_size, args.sequence_len), dtype=torch.long).to(device)))

# padding sequence with <PAD>
def padding(data, fix_length, pad, add_first="", add_last=""):
    """Pad or truncate a token list to exactly ``fix_length`` items.

    Optionally prepends ``add_first`` and appends ``add_last`` (e.g.
    BOS/EOS markers) before padding/truncating. Note: when markers are
    given, ``data`` is mutated in place.

    Args:
        data: list of tokens to pad.
        fix_length: desired output length.
        pad: padding token appended when ``data`` is too short.
        add_first: optional token inserted at position 0.
        add_last: optional token appended at the end.

    Returns:
        A new list of exactly ``fix_length`` tokens.
    """
    if add_first:
        data.insert(0, add_first)
    if add_last:
        data.append(add_last)
    pad_data = []
    data_len = len(data)
    for idx in range(fix_length):
        if idx < data_len:
            pad_data.append(data[idx])
        else:
            pad_data.append(pad)
    # BUG FIX: the result was built but never returned.
    return pad_data
class Visualizer():
    """TensorBoard + HTML visualization helper for a training run.

    Wraps a SummaryWriter (under ``default_dir/<name>_<timestamp>``), an
    optional HTML page of intermediate result images, and a plain-text
    loss log stored in the model's save directory.
    """

    def __init__(self,
                 model,
                 name="DeepFaceDrawing",
                 default_dir='runs',
                 display_architecture=False,
                 display_in_train=True):
        """Create the writer, optional HTML page, and the loss log file.

        Parameters:
            model -- object exposing ``save_dir``, ``get_current_losses()``
                and ``get_current_visuals()`` (presumably a project model
                wrapper — verify against callers)
            name -- run name used as the TensorBoard log-dir prefix
            default_dir -- root directory for TensorBoard runs
            display_architecture -- if True, log the model graph immediately
            display_in_train -- if True, prepare an HTML page for images
                produced during training
        """
        # Timestamped log dir so repeated runs never collide.
        self.writer = SummaryWriter(
            os.path.join(
                default_dir,
                name + '_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
        self.model = model

        if display_architecture:
            self.inspect_model()

        if display_in_train:
            p = os.path.join(self.model.save_dir, 'display_in_train')
            if not os.path.exists(p):
                os.mkdir(p)
            self.path_display_in_train = p
            self.html = HTML(p, 'display_in_train_html')

        # create a logging file to store training losses (append mode so a
        # resumed run keeps the previous history)
        self.log_name = os.path.join(self.model.save_dir, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write(
                '================ Training Loss (%s) ================\n' % now)

    def inspect_model(self):
        """Log the model's graph to TensorBoard.

        NOTE(review): ``add_graph`` is called without an example input;
        confirm this works for the wrapped model type.
        """
        self.writer.add_graph(self.model)

    def plot_loss(self, n_iter, name='Loss/train'):
        """Plot the model's current loss(es) at iteration ``n_iter``."""
        loss = self.model.get_current_losses()
        if isinstance(loss, dict):
            # several named losses -> one chart with multiple curves
            self.writer.add_scalars(name, loss, n_iter)
        else:
            self.writer.add_scalar(name, loss, n_iter)

    def display_current_results(self, epoch):
        """Save current visuals to disk and add them to the HTML page."""
        visual_im = self.model.get_current_visuals()
        self.html.add_header('epoch ' + str(epoch))
        ims, txts, links = [], [], []
        for key in visual_im.keys():
            link = os.path.join(self.path_display_in_train, 'images',
                                str(epoch) + '_' + key + '.jpg')
            torchvision.utils.save_image(visual_im[key], link)

            # the HTML page references images relative to its own folder
            link = str(epoch) + '_' + key + '.jpg'
            ims.append(link)
            txts.append(key)
            links.append(link)

        self.html.add_images(ims, txts, links)
        self.html.save()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, iters,
                                                           t_comp)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
Esempio n. 22
0
'''
# 2. Writing to tensorboard
#get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
#create grid of images
img_grid = torchvision.utils.make_grid(images)
#show images
matplotlib_imshow(img_grid, one_channel=True)
#write to tensorboard
writer.add_image('four_fashion_mnist_images', img_grid)
''' 
- Graphs
'''
# 3. Inspect the model using TensorBoard
writer.add_graph(net, images.to(device))
writer.close()
''' 
- Projector
'''


# 4. Adding a "Projector" to TensorBoard.
# what is projector?
def select_n_random(data, labels, n=100):
    '''
    Selects n random datapoints and their corresponding labels from a dataset
    '''
    assert len(data) == len(labels)

    perm = torch.randperm(len(data))
Esempio n. 23
0
class Trainer():
    def __init__(self,
                 generator,
                 moment_network,
                 train_set,
                 training_params,
                 device=None,
                 scores=None,
                 tensorboard=False,
                 save_folder="runs/run"):
        """
            generator: a nn.Module child class serving as a generator network
            moment_network: a nn.Module child class serving as the moment network
            loader: a training data loader
            scores: None, or a dict of shape {'name':obj} with score object with a __call__ function that returns a score
            
            training_params: dict of training parameters with:
                n0: number of objectives
                nm: number of moments trainig step
                ng: number of generating training steps
                lr: learning rate
                beta1 / beta2: Adam parameters 
                acw: activation wieghts
                alpha: the norm penalty parameter
                gen_batch_size: the batch size to train the generator
                mom_batch_size: the batch size to train the moment network
                eval_batch_size: the batch size to evaluate the generated
                eval_size: total number of generated samples on which to evaluate the scores
            
            tensorboard: whether to use tensorboard to save training information
            save_folder: root folder to save the training information

        """
        self.G = generator
        self.MoNet = moment_network
        self.train_set = train_set
        self.training_params = training_params
        self.nm = training_params["nm"]
        self.ng = training_params["ng"]
        self.no = training_params["no"]
        self.no_obj = 0  #current objective
        self.n_moments = training_params["n_moments"]
        self.gen_batch_size = training_params["gen_batch_size"]
        self.eval_batch_size = training_params["eval_batch_size"]
        self.learn_moments = training_params["learn_moments"]

        lr, beta1, beta2 = self.training_params["lr"], self.training_params[
            "beta1"], self.training_params["beta2"]
        self.optimizerG = optim.Adam(self.G.parameters(),
                                     lr=lr,
                                     betas=(beta1, beta2))
        self.optimizerM = optim.Adam(self.MoNet.parameters(),
                                     lr=lr,
                                     betas=(beta1, beta2))

        self.LM = []
        self.LG = []
        self.iter = 0
        self.device = device

        self.cross_entropy = F.binary_cross_entropy
        self.mse = MSELoss(reduction="sum")

        #to track the evolution of generated images from a single batch of noises
        self.fixed_z = torch.randn(20, self.G.dims[0], device=self.device)

        #monitoring the progress of the training with the evaluation scores
        self.scores = scores

        #saving training info
        self.run_folder = save_folder
        self.save_path_img = self.run_folder + "/results/images/"
        self.save_path_checkpoints = self.run_folder + "/checkpoints/"

        #monitoring through tensorboard
        if tensorboard:
            comment = ''.join([
                '{}={} '.format(key, training_params[key])
                for key in training_params
            ])
            self.tb = SummaryWriter(self.run_folder, comment=comment)
            self.tb.add_graph(generator, self.fixed_z)

    def train_monet(self):
        #reshuffle training data
        loader = iter(
            torch.utils.data.DataLoader(
                self.train_set,
                shuffle=True,
                batch_size=self.training_params["mom_batch_size"]))
        for i in range(self.nm):
            batch = loader.next()
            samples, _ = batch
            samples = samples.to(self.device)
            samples = (samples * 2) - 1

            sample_size = samples.size(0)
            one_labels = torch.ones(sample_size, device=self.device)
            zero_labels = torch.zeros(sample_size, device=self.device)

            #generating latent vector
            #self.dims = [Z_dim, H1_dim, H2_dim, H3_dim, X_dim]
            z = torch.randn(sample_size, self.G.dims[0], device=self.device)
            res = self.G(z)
            prob_trues, output_trues, _ = self.MoNet(samples)
            prob_gen, output_gen, _ = self.MoNet(res)

            prob_trues, prob_gen = prob_trues.squeeze(), prob_gen.squeeze()
            LM_samples = self.cross_entropy(prob_trues, one_labels)
            LM_gen = self.cross_entropy(prob_gen, zero_labels)
            LM = LM_samples + LM_gen

            #We now need to compute the gradients to add the regularization term
            mean_output = output_trues.mean()
            self.optimizerM.zero_grad()
            grad_monet = self.MoNet.get_gradients(mean_output)
            #This is the sum of gradients, so we divide by the batch size
            grad_monet = (grad_monet / sample_size).squeeze()
            grad_norm = torch.dot(grad_monet, grad_monet)
            LM = LM_samples + LM_gen + self.training_params["alpha"] * (
                (grad_norm - 1)**2)
            #LM = LM_samples + LM_gen
            #print("LM loss: {:.4}".format(float(LM)))
            #Add to tensorboard
            if self.tb:
                self.tb.add_scalar(
                    'LossMonet/objective_{}'.format(self.no_obj + 1),
                    float(LM), i + 1)
            self.LM.append(float(LM))
            if i % 50 == 0:
                logger.info("Moment Network Iteration {}/{}: LM: {:.6}".format(
                    i + 1, self.nm, LM.item()))

            self.optimizerM.zero_grad()
            LM.backward()
            self.optimizerM.step()

            del grad_monet
            del batch

    def eval_true_moments(self):
        loader = torch.utils.data.DataLoader(
            self.train_set,
            shuffle=True,
            batch_size=self.training_params["mom_batch_size"])
        #Calculate the moment vector over the entire dataset:
        moments = torch.zeros(self.n_moments, device=self.device)
        for i, batch in enumerate(loader):
            #if i % 100 == 0:
            #   print("Computing real data Moment Features... {}/{}".format(i+1, len(loader)))
            samples, _ = batch
            samples = samples.to(self.device)
            sample_size = samples.size(0)
            #Scaling true images to tanh activation interval:
            samples = (samples * 2) - 1
            self.optimizerM.zero_grad()
            moments_b = self.MoNet.get_moment_vector(
                samples,
                sample_size,
                weights=self.training_params["activation_weight"],
                detach=True)
            moments = ((i) * moments + moments_b) / (i + 1)
            del batch
            del samples
            del moments_b
        return moments

    def train_generator(self, true_moments):
        """Optimize the generator for ``self.ng`` iterations against a fixed
        target moment vector, logging the loss to tensorboard and ``self.LG``."""
        weight = self.training_params["activation_weight"]
        for step in range(self.ng):
            # Draw a fresh latent batch and generate samples from it.
            latent = torch.randn(self.gen_batch_size,
                                 self.G.dims[0],
                                 device=self.device)
            generated = self.G(latent)
            self.optimizerM.zero_grad()
            gen_moments = self.MoNet.get_moment_vector(generated,
                                                       self.gen_batch_size,
                                                       weights=weight)

            del latent
            del generated

            # Squared distance between the true and generated moment vectors
            # (equivalent to the dot product of the difference with itself).
            loss_g = self.mse(true_moments, gen_moments)
            # Record per-iteration loss on tensorboard, keyed by objective.
            if self.tb:
                self.tb.add_scalar(
                    'LossGenerator/objective_{}'.format(self.no_obj + 1),
                    float(loss_g), step + 1)
            self.LG.append(float(loss_g))
            if step % 100 == 0:
                logger.info("Generator Iteration {}/{}: LG: {:.6}".format(
                    step + 1, self.ng, loss_g.item()))
            self.optimizerG.zero_grad()
            loss_g.backward()
            self.optimizerG.step()

            del gen_moments

    def generate_and_display(self, z, save=False, save_path=None):
        """Generate samples from latent batch ``z`` and either display the
        image grid or save it to ``save_path`` when ``save`` is True."""
        images = self.G(z).detach().cpu().reshape(-1, 3, 32, 32)
        # Map tanh outputs from [-1, 1] back into [0, 1] for display.
        images = (images + 1) / 2
        grid = torchvision.utils.make_grid(images, nrow=10)  # 10 images per row
        # Mirror the grid to tensorboard, indexed by the current objective.
        if self.tb:
            self.tb.add_image('generated images', grid, self.no_obj)
        fig = plt.figure(figsize=(15, 15))
        plt.imshow(np.transpose(grid, (1, 2, 0)))
        if save:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close(fig)

    def eval(self):
        """Score generated images with every metric in ``self.scores`` and
        return the per-metric averages over ``eval_size`` samples."""
        logger.info("Evaluating generated images with scores: {}".format(
            self.scores.keys()))
        n_loops = self.training_params["eval_size"] // self.eval_batch_size
        results = {name: 0 for name in self.scores}
        for _ in range(n_loops):
            # Generate one evaluation batch without tracking gradients.
            with torch.no_grad():
                noise = torch.randn(self.eval_batch_size,
                                    self.G.dims[0],
                                    device=self.device)
                generated = self.G(noise).cpu()
            samples = InceptionScore.preprocess(generated)
            for name, scoring in self.scores.items():
                results[name] += scoring(samples)
        # Average each accumulated score over the number of batches.
        for name in results:
            results[name] /= n_loops
        return results

    def load_from_checkpoints(self, path):
        """
        Loads network parameters and training info from checkpoint.

        Args:
            path: path to a checkpoint file written by ``train``.

        Returns:
            Tuple ``(last_objective, lossG, lossM)``; ``lossM`` is None when
            the checkpoint was saved without a learned moment network.
        """
        logger.info(
            "Loading network parameters and training info from checkpoint...")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(path, map_location=self.device)
        self.G.load_state_dict(checkpoint['generator_state_dict'])
        self.optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        self.G.train()

        if self.learn_moments:
            self.MoNet.load_state_dict(checkpoint['monet_state_dict'])
            self.optimizerM.load_state_dict(
                checkpoint['optimizerM_state_dict'])
            self.MoNet.train()

        last_objective = checkpoint['objective']
        lossG = checkpoint['last_lossG']
        # BUG FIX: 'last_lossM' is only written when learn_moments is enabled
        # (see the save_dict in train), so a plain lookup raised KeyError for
        # fixed-moment checkpoints; .get returns None instead.
        lossM = checkpoint.get('last_lossM')

        return last_objective, lossG, lossM

    def train(self, save_images=False, from_checkpoint=None):
        """
        Run the full training loop over ``self.no`` objectives.

        Args:
            save_images: when True, save the periodic sample grids to
                ``self.save_path_img`` instead of displaying them.
            from_checkpoint: optional path of a checkpoint to resume from.
        """
        if not self.learn_moments:
            # With a fixed moment network the target moments never change,
            # so compute them once up front.
            true_moments = self.eval_true_moments()

        # BUG FIX: default to objective 0 — previously ``last_objective`` was
        # assigned only when resuming, so a fresh run raised NameError in the
        # range() below.
        last_objective = 0
        if from_checkpoint:
            last_objective, lossG, lossM = self.load_from_checkpoints(
                from_checkpoint)
            logger.info(
                "Starting training from Objective: {}, lossG: {}, lossM: {}".
                format(last_objective, lossG, lossM))

        for i in range(last_objective, self.no):
            # Track the number of objectives solved.
            self.no_obj = i

            start = time.time()
            if self.learn_moments:
                logger.info("Training Moment Network...")
                self.train_monet()
                # Moments must be re-evaluated after MoNet changes.
                logger.info("Evaluating true moments value...")
                true_moments = self.eval_true_moments()
            logger.info("Training Generator")
            self.train_generator(true_moments)
            self.iter += 1
            stop = time.time()
            duration = (stop - start) / 60

            if self.learn_moments:
                logger.info(
                    "Objective {}/{} - {:.2} minutes: LossMonet: {:.6} LossG: {:.6}"
                    .format(i + 1, self.no, duration, self.LM[-1],
                            self.LG[-1]))
            else:
                logger.info(
                    "Objective {}/{} - {:.2} minutes: LossG: {:.6}".format(
                        i + 1, self.no, duration, self.LG[-1]))

            self.generate_and_display(
                self.fixed_z,
                save=save_images,
                save_path=self.save_path_img +
                "generated_molm_cifar10_iter{}.png".format(i))

            # Periodic checkpointing and (optionally) score evaluation.
            if i % SAVING_FREQUENCY == 0:
                logger.info("Saving model ...")
                save_path_checkpoints = self.save_path_checkpoints + "molm_cifar10_iter{}.pt".format(
                    i)
                save_dict = {
                    'monet_state_dict': self.MoNet.state_dict(),
                    'generator_state_dict': self.G.state_dict(),
                    'optimizerG_state_dict': self.optimizerG.state_dict(),
                    'objective': i + 1,
                    'last_lossG': self.LG[-1]
                }
                if self.learn_moments:
                    # Moment-network loss/optimizer only exist in this mode.
                    save_dict["last_lossM"] = self.LM[-1]
                    save_dict[
                        "optimizerM_state_dict"] = self.optimizerM.state_dict(
                        )

                torch.save(save_dict, save_path_checkpoints)

                if self.scores:
                    scores = self.eval()
                    logger.info(scores)
                    # Add evaluation scores to tensorboard.
                    if self.tb:
                        for score in scores:
                            self.tb.add_scalar('Scores/{}'.format(score),
                                               scores[score], i + 1)

            # Updating parameter/gradient histograms on tensorboard.
            if self.tb:
                for name, param in self.G.named_parameters():
                    self.tb.add_histogram('generator.{}'.format(name), param,
                                          i + 1)
                    self.tb.add_histogram('generator.{}.grad'.format(name),
                                          param.grad, i + 1)
                for name, param in self.MoNet.named_parameters():
                    self.tb.add_histogram('momentNetwork.{}'.format(name),
                                          param, i + 1)
                    self.tb.add_histogram('momentNetwork.{}.grad'.format(name),
                                          param.grad, i + 1)
Esempio n. 24
0
    else:
        checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                       settings.TIME_NOW)

    #use tensorboard
    if not os.path.exists(settings.LOG_DIR):
        os.mkdir(settings.LOG_DIR)

    #since tensorboard can't overwrite old values
    #so the only way is to create a new tensorboard log
    writer = SummaryWriter(
        log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW))
    input_tensor = torch.Tensor(1, 3, 32, 32)
    if args.gpu:
        input_tensor = input_tensor.cuda()
    writer.add_graph(net, input_tensor)

    #create checkpoint folder to save model
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

    best_acc = 0.0
    if args.resume:
        best_weights = best_acc_weights(
            os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder))
        if best_weights:
            weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                        recent_folder, best_weights)
            print('found best acc weights file:{}'.format(weights_path))
            print('load best training file to test acc...')
def train(epoch=5, freeze=True):
    """
    Train the Encoder classifier and validate it after every epoch.

    Args:
        epoch: total number of epochs to run.
        freeze: when True, freeze the ResNet feature-extractor weights so
            only the remaining layers are trained.
    """
    tb = SummaryWriter()
    # Defining model
    model = Encoder()
    print(model)

    if freeze:
        for param in model._resnet_extractor.parameters():
            # BUG FIX: the original wrote ``param.require_grad`` (a typo),
            # which merely creates a new attribute and freezes nothing.
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    transform = set_transform()
    train_loader = get_loader(train_csv, batch_size, transform=transform)
    valid_loader = get_loader(valid_csv, batch_size, transform=transform)

    # Grab one batch up front for the tensorboard image grid / model graph.
    img, cls = next(iter(train_loader))
    grid = torchvision.utils.make_grid(img)
    tb.add_image('images', grid, 0)

    if torch.cuda.is_available():
        model = model.cuda()
        img = img.cuda()

    total_train_loss = []
    total_val_loss = []

    best_valid = 100000000
    not_improve = 0

    tb.add_graph(model, img)
    # BUG FIX: range(1, epoch + 1) so that exactly ``epoch`` epochs run; the
    # original range(1, epoch) silently dropped the last one.
    for e in range(1, epoch + 1):

        loss_train = []
        loss_val = 0
        acc_train = 0
        acc_val = 0

        # ---- training pass ----
        model.train()
        num_iter = 1
        for i, (images, classes) in enumerate(train_loader):

            optimizer.zero_grad()

            if torch.cuda.is_available():
                images = images.cuda()
                classes = classes.cuda()

            feature_image = model(images)
            _, preds = torch.max(feature_image.data, 1)

            loss = criterion(feature_image, classes)
            loss.backward()
            optimizer.step()

            loss_train.append(loss.cpu().detach().numpy())
            acc_train += torch.sum(preds == classes)

            del feature_image, classes, preds
            torch.cuda.empty_cache()

            num_iter = i + 1
            if i % 10 == 0:
                print(f"Epoch ({e}/{epoch}) Iter: {i+1} Loss: {loss}")

        avg_loss = sum(loss_train) / num_iter
        print(f"\t\tTotal iter: {num_iter} AVG loss: {avg_loss}")
        tb.add_scalar("Train_Loss", avg_loss, e)
        # NOTE(review): "100 - loss" is not an accuracy; acc_train is
        # accumulated above but never reported — consider using it instead.
        tb.add_scalar("Train_Accuracy", 100 - avg_loss, e)

        total_train_loss.append(avg_loss)

        # ---- validation pass ----
        model.eval()
        num_iter_val = 1
        # BUG FIX: the original ran the forward pass on CPU inputs and moved
        # only the *output* to the GPU afterwards (a device mismatch on CUDA
        # hosts), and it tracked gradients it never used. Move the inputs
        # first and disable autograd for the whole pass.
        with torch.no_grad():
            for i, (images, classes) in enumerate(valid_loader):

                if torch.cuda.is_available():
                    images = images.cuda()
                    classes = classes.cuda()

                feature_image = model(images)
                _, preds = torch.max(feature_image.data, 1)

                loss = criterion(feature_image, classes)

                loss_val += loss.cpu().detach().numpy()
                acc_val += torch.sum(preds == classes)

                num_iter_val = i + 1
                del feature_image, classes, preds
                torch.cuda.empty_cache()

        avg_val = loss_val / num_iter_val
        print(f"\t\tValid Loss: {avg_val}")

        tb.add_scalar("Validation_Loss", avg_val, e)
        tb.add_scalar("Validation_Accuracy", 100 - avg_val, e)

        # Checkpoint on validation improvement; count stalls otherwise.
        if avg_val < best_valid:
            total_val_loss.append(avg_val)
            model_save = save_path + "/best_model.th"
            torch.save(model.state_dict(), model_save)
            best_valid = avg_val
            print(f"Model saved to path save/")
            not_improve = 0
        else:
            not_improve += 1
            print(f"Not Improved {not_improve} times ")
        # Early stopping after six epochs without improvement.
        if not_improve == 6:
            break

    save_loss = {"train": total_train_loss, "valid": total_val_loss}
    with open(save_path + "/losses.pickle", "wb") as files:
        pickle.dump(save_loss, files)

    tb.close()
Esempio n. 26
0
                'Train Epoch: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain loss: {:>2.4f}\tmean val loss: {:>2.4f}\tmean '
                'dice coefficient: {:>2.4f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(),
                    mean_val_loss, mean_dice_coeff))


def save(epoch):
    """Persist the current model weights under ``<base_path>/checkpoints/<log_name>/``.

    Args:
        epoch: current epoch number (kept for interface compatibility; the
            checkpoint filename is fixed, so each call overwrites the last).
    """
    save_dir = os.path.join(base_path, "checkpoints", log_name)
    # exist_ok creates all missing parents in one call and avoids the
    # race-prone exists()/makedirs() dance of the original.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(save_dir, "checkpoint.pth.tar"))


if __name__ == "__main__":
    # Resume from epoch 52 when a saved model is being loaded; otherwise
    # start from the first epoch.
    model_load = False
    if model_load:  # idiom fix: avoid explicit `== True`
        start_epoch = 52
        epoch_range = range(start_epoch, epochs + 1)
    else:
        epoch_range = range(1, epochs + 1)
    # Trace the model graph once with a dummy batch for tensorboard.
    dummy = torch.rand(4, 3, 256, 256).cuda()
    writer.add_graph(model, (dummy, ))
    for epoch in epoch_range:
        train(epoch)
        save(epoch)
Esempio n. 27
0
        g_writer = SummaryWriter(os.path.join(SAVEPATH, "tb_g_net"))
        l_writer_train = SummaryWriter(os.path.join(SAVEPATH,
                                                    "tb_l_net_train"))
        l_writer_eval = SummaryWriter(os.path.join(SAVEPATH, "tb_l_net_eval"))
        #generate random trajectory to feed-forward as batch
        w0 = workers[0]
        env = gym.make('gym_boxworld:boxworld-v0', **random_config())

        env.reset()

        trajectory = []
        while True:
            s, _, done, _ = env.step(np.random.choice(4))
            s = torch.tensor([s.T], dtype=torch.float)
            trajectory.append(s)
            if done:
                break
        # write graph to file
        s_ = torch.cat(trajectory).detach()

        g_writer.add_graph(g_net, s_)
        w0.l_net.eval()
        l_writer_eval.add_graph(w0.l_net, s_)
        w0.l_net.train()
        l_writer_train.add_graph(w0.l_net, s_)
        g_writer.close()
        l_writer_eval.close()
        l_writer_train.close()

        ###visualize gradients at critical points
        outputs = model(inputs)
        batch_loss = cross_entropy(outputs, labels)
        batch_loss.backward()
        adam.step()

        if step == 1:
            save_images = inputs

        if step % 100 == 99:

            correct = 0
            total = 0
            with torch.no_grad():
                for data in loader_test:
                    images, labels = data
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            print(('Epoch: %d, Step %5d, Images: %5d  loss: ' +
                   '%3f  test-set accuracy: %5.2f') %
                  (epoch + 1, step + 1, (step + 1) * batch_size,
                   batch_loss.item(), model.run_inference(loader_test)))
            accumulated_loss = 0.0

torch.save(model.state_dict(), model.save_filename())

writer.add_graph(model, save_images)
writer.close()
# get some random training images
dataiter = iter(train_loader)
# BUG FIX: use the builtin next(); the iterator .next() method does not
# exist in Python 3 (and was removed from torch DataLoader iterators).
images, labels = next(dataiter)

# create grid of images
img_grid = torchvision.utils.make_grid(images)

# show images
matplotlib_imshow(img_grid, one_channel=True)

# write to tensorboard
writer.add_image('four_fashion_mnist_images', img_grid)

# write network to tensorboard
writer.add_graph(
    net, images
)  # model is moved to 'cpu' to make sure that the model is in local device for tensorboard'
writer.close()

# Set loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# helper functions


def images_to_probs(net, images):
    '''
    Generates predictions and corresponding probabilities from a trained
    network and a list of images
    '''
    # NOTE(review): the body appears to have been truncated in this paste —
    # as written the function only carries a docstring and returns None.
def train_variant(conv, fcl, args):
    """
    Build, train and checkpoint one VGG variant on CIFAR-100.

    Args:
        conv: conv-layer variant spec forwarded to construct_vgg_variant.
        fcl: fully-connected variant spec forwarded to construct_vgg_variant.
        args: parsed CLI namespace; reads args.gpu, args.lr, args.warm and
            args.resume, and overwrites args.net with the variant name.
    """
    net, arch_name = construct_vgg_variant(conv_variant=conv,
                                           fcl_variant=fcl,
                                           batch_norm=True,
                                           progress=True,
                                           pretrained=False)
    args.net = arch_name
    if args.gpu:  # use_gpu
        net = net.cuda()

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    train_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=settings.MILESTONES,
        gamma=0.2)  # learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    # Resume from the most recent run folder, or start a fresh timestamped one.
    if args.resume:
        recent_folder = most_recent_folder(os.path.join(
            settings.CHECKPOINT_PATH, args.net),
                                           fmt=settings.DATE_FORMAT)
        if not recent_folder:
            raise Exception('no recent folder were found')

        checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                       recent_folder)
    else:
        checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                       settings.TIME_NOW)

    # use tensorboard; makedirs also creates missing parent directories
    # (os.mkdir would fail on a fresh tree).
    if not os.path.exists(settings.LOG_DIR):
        os.makedirs(settings.LOG_DIR)

    # since tensorboard can't overwrite old values
    # the only way is to create a new tensorboard log
    writer = SummaryWriter(
        log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW))
    input_tensor = torch.Tensor(1, 3, 32, 32)
    if args.gpu:
        input_tensor = input_tensor.cuda()
    writer.add_graph(net, input_tensor)

    # create checkpoint folder to save model
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

    best_acc = 0.0
    if args.resume:
        # Re-establish the best accuracy so far from the best-weights file.
        best_weights = best_acc_weights(
            os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder))
        if best_weights:
            weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                        recent_folder, best_weights)
            print('found best acc weights file:{}'.format(weights_path))
            print('load best training file to test acc...')
            net.load_state_dict(torch.load(weights_path))
            best_acc = eval_training(tb=False)
            print('best acc is {:0.2f}'.format(best_acc))

        # Then load the most recent weights to continue training from.
        recent_weights_file = most_recent_weights(
            os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder))
        if not recent_weights_file:
            raise Exception('no recent weights file were found')
        weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                    recent_folder, recent_weights_file)
        print('loading weights file {} to resume training.....'.format(
            weights_path))
        net.load_state_dict(torch.load(weights_path))

        resume_epoch = last_epoch(
            os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder))

    train_params = {
        'net': net,
        'warmup_scheduler': warmup_scheduler,
        'loss_function': loss_function,
        'optimizer': optimizer,
        'writer': writer
    }
    # BUG FIX: range(1, EPOCH + 1) so that exactly settings.EPOCH epochs run;
    # the original range(1, settings.EPOCH) dropped the final epoch.
    for epoch in range(1, settings.EPOCH + 1):
        if epoch > args.warm:
            # NOTE(review): passing epoch to scheduler.step() is deprecated in
            # recent torch versions — confirm the target torch version.
            train_scheduler.step(epoch)

        if args.resume:
            if epoch <= resume_epoch:
                continue

        train(epoch=epoch, **train_params)
        acc = eval_training(epoch=epoch, **train_params)

        # start to save best performance model after learning rate decay to 0.01
        if epoch > settings.MILESTONES[1] and best_acc < acc:
            torch.save(
                net.state_dict(),
                checkpoint_path.format(net=args.net, epoch=epoch, type='best'))
            best_acc = acc
            continue

        if not epoch % settings.SAVE_EPOCH:
            torch.save(
                net.state_dict(),
                checkpoint_path.format(net=args.net,
                                       epoch=epoch,
                                       type='regular'))

    writer.close()