Example #1
    def prune(self):
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=self.args.lr,
                                    momentum=self.args.momentum,
                                    weight_decay=self.args.wd)
        if self.args.dataset == 'cifar10':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.args.epochs)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.args.epochs, eta_min=0.0004)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.args.lr_decay_step, gamma=0.1)

        start_epoch = 0
        best_top1_acc = 0
        best_top5_acc = 0
        # adjust the learning rate according to the checkpoint
        for epoch in range(start_epoch):
            scheduler.step()

        # train the model
        epoch = start_epoch
        while epoch < self.args.epochs:
            train_obj, train_top1_acc, train_top5_acc = self.train(
                epoch, self.train_loader, self.model, self.criterion,
                optimizer, scheduler)
            valid_obj, valid_top1_acc, valid_top5_acc = self.validate(
                epoch, self.val_loader, self.model, self.criterion)

            is_best = False
            if valid_top1_acc > best_top1_acc:
                best_top1_acc = valid_top1_acc
                is_best = True

            utils.save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': self.model.state_dict(),
                    'best_top1_acc': best_top1_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best, self.args.job_dir)

            epoch += 1
            self.logger.info("=>Best accuracy {:.3f}".format(best_top1_acc))
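
Note: utils.save_checkpoint itself is not shown in this example. A minimal sketch of what such a helper commonly looks like, matching the call above (file names and layout here are assumptions, not this repository's actual implementation):

import os
import shutil

import torch


def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    # Always persist the latest training state (weights, optimizer, epoch, ...)
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        # Keep a separate copy of the best-performing checkpoint
        shutil.copyfile(path, os.path.join(save_dir, 'model_best.pth.tar'))
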
Example #2
def make_checkpoint(model,
                    optimizer,
                    criterion,
                    epoch=None,
                    time_str=None,
                    args=None):
    fname = get_experiment_name(args)
    saved_path = save_checkpoint(model,
                                 optimizer,
                                 criterion,
                                 experiment_name=fname,
                                 epoch=epoch,
                                 time_str=time_str)

    print('Model saved to {}'.format(saved_path))
Example #3
def run(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    logdir = args.logdir
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        # model
        model = Tacotron2(config)
        # optimizer
        optimizer = Tacotron2Optimizer(
            optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.l2,
                       betas=(0.9, 0.999),
                       eps=1e-6))

    else:
        start_epoch, epochs_since_improvement, model, optimizer, best_loss = load_checkpoint(
            logdir, checkpoint)

    logger = Logger(config.logdir, config.experiment, 'tacotron2')

    # Move to GPU, if available
    model = model.to(config.device)

    criterion = Tacotron2Loss()

    # Custom dataloaders
    train_dataset = Text2MelDataset(config.train_files, config)
    train_loader = Text2MelDataLoader(train_dataset,
                                      config,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
    valid_dataset = Text2MelDataset(config.valid_files, config)
    valid_loader = Text2MelDataLoader(valid_dataset,
                                      config,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args.epochs):
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           criterion=criterion,
                           epoch=epoch,
                           logger=logger)

        lr = optimizer.lr
        print('\nLearning rate: {}'.format(lr))
        step_num = optimizer.step_num
        print('Step num: {}\n'.format(step_num))

        scalar_dict = {'train_epoch_loss': train_loss, 'learning_rate': lr}
        logger.log_epoch('train', epoch, scalar_dict=scalar_dict)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           criterion=criterion,
                           logger=logger)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: {}\n".format(
                epochs_since_improvement))
        else:
            epochs_since_improvement = 0

        scalar_dict = {'valid_epoch_loss': valid_loss}
        logger.log_epoch('valid', epoch, scalar_dict=scalar_dict)

        # Save checkpoint
        if epoch % args.save_freq == 0:
            save_checkpoint(logdir, epoch, epochs_since_improvement, model,
                            optimizer, best_loss, is_best)
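
The Tacotron2Optimizer above wraps a plain Adam optimizer and exposes lr and step_num, which suggests it applies its own schedule inside step(). Its implementation is not part of this example; the following sketch only illustrates that wrapping pattern, with a hypothetical linear-warmup schedule and made-up defaults:

class ScheduledOptimizer:
    # Illustrative wrapper that tracks its own step count and sets the learning
    # rate on every step; the real Tacotron2Optimizer may use a different schedule.
    def __init__(self, optimizer, init_lr=1e-3, warmup_steps=4000):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.step_num = 0
        self.lr = init_lr

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.step_num += 1
        # Linear warmup, then a constant rate (for illustration only)
        self.lr = self.init_lr * min(1.0, self.step_num / self.warmup_steps)
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr
        self.optimizer.step()
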
Example #4
def train(writer,
          loader_c,
          loader_sup,
          validation_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          is_mixed_precision,
          with_sup,
          num_classes,
          categories,
          input_sizes,
          val_num_steps=1000,
          loss_freq=10,
          tensorboard_prefix='',
          best_mIoU=0):
    #######
    # c for carry (pseudo labeled), sup for support (labeled with ground truth) -_-
    # Don't ask me why
    #######
    # Poly training schedule
    # Epoch length measured by "carry" (c) loader
    # Batch ratio is determined by loaders' own batch size
    # Validate and find the best snapshot per val_num_steps
    loss_num_steps = int(len(loader_c) / loss_freq)
    net.train()
    epoch = 0
    if with_sup:
        iter_sup = iter(loader_sup)

    if is_mixed_precision:
        scaler = GradScaler()

    # Training
    running_stats = {
        'disagree': -1,
        'current_win': -1,
        'avg_weights': 1.0,
        'loss': 0.0
    }
    while epoch < num_epochs:
        conf_mat = ConfusionMatrix(num_classes)
        time_now = time.time()
        for i, data in enumerate(loader_c, 0):
            # Combine loaders (maybe just alternate training will work)
            if with_sup:
                inputs_c, labels_c = data
                inputs_sup, labels_sup = next(iter_sup, (0, 0))
                if type(inputs_sup) == type(labels_sup) == int:
                    iter_sup = iter(loader_sup)
                    inputs_sup, labels_sup = next(iter_sup, (0, 0))

                # Formatting (prob: label + max confidence, label: just label)
                float_labels_sup = labels_sup.clone().float().unsqueeze(1)
                probs_sup = torch.cat(
                    [float_labels_sup,
                     torch.ones_like(float_labels_sup)],
                    dim=1)
                probs_c = labels_c.clone()
                labels_c = labels_c[:, 0, :, :].long()

                # Concatenating
                inputs = torch.cat([inputs_c, inputs_sup])
                labels = torch.cat([labels_c, labels_sup])
                probs = torch.cat([probs_c, probs_sup])

                probs = probs.to(device)
            else:
                inputs, labels = data

            # Normal training
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with autocast(is_mixed_precision):
                outputs = net(inputs)['out']
                outputs = torch.nn.functional.interpolate(outputs,
                                                          size=input_sizes[0],
                                                          mode='bilinear',
                                                          align_corners=True)
                conf_mat.update(labels.flatten(), outputs.argmax(1).flatten())

                if with_sup:
                    loss, stats = criterion(outputs, probs, inputs_c.shape[0])
                else:
                    loss, stats = criterion(outputs, labels)

            if is_mixed_precision:
                accelerator.backward(scaler.scale(loss))
                scaler.step(optimizer)
                scaler.update()
            else:
                accelerator.backward(loss)
                optimizer.step()

            lr_scheduler.step()

            # Logging
            for key in stats.keys():
                running_stats[key] += stats[key]
            current_step_num = int(epoch * len(loader_c) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                for key in running_stats.keys():
                    print('[%d, %d] ' % (epoch + 1, i + 1) + key + ' : %.4f' %
                          (running_stats[key] / loss_num_steps))
                    writer.add_scalar(tensorboard_prefix + key,
                                      running_stats[key] / loss_num_steps,
                                      current_step_num)
                    running_stats[key] = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
                current_step_num == num_epochs * len(loader_c) - 1:
                # Apex bug https://github.com/NVIDIA/apex/issues/706, fixed in PyTorch1.6, kept here for BC
                test_pixel_accuracy, test_mIoU = test_one_set(
                    loader=validation_loader,
                    device=device,
                    net=net,
                    num_classes=num_classes,
                    categories=categories,
                    output_size=input_sizes[2],
                    is_mixed_precision=is_mixed_precision)
                writer.add_scalar(tensorboard_prefix + 'test pixel accuracy',
                                  test_pixel_accuracy, current_step_num)
                writer.add_scalar(tensorboard_prefix + 'test mIoU', test_mIoU,
                                  current_step_num)
                net.train()

                # Record the best model (straight to disk)
                if test_mIoU > best_mIoU:
                    best_mIoU = test_mIoU
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision)

        # Evaluate training accuracies (same metric as validation, but must be on-the-fly to save time)
        acc_global, acc, iu = conf_mat.compute()
        print(categories)
        print(('global correct: {:.2f}\n'
               'average row correct: {}\n'
               'IoU: {}\n'
               'mean IoU: {:.2f}').format(
                   acc_global.item() * 100,
                   ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
                   ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
                   iu.mean().item() * 100))

        train_pixel_acc = acc_global.item() * 100
        train_mIoU = iu.mean().item() * 100
        writer.add_scalar(tensorboard_prefix + 'train pixel accuracy',
                          train_pixel_acc, epoch + 1)
        writer.add_scalar(tensorboard_prefix + 'train mIoU', train_mIoU,
                          epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    return best_mIoU
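
The mixed-precision branch above scales the loss with a GradScaler but routes the backward pass through an accelerator object defined elsewhere. For reference, a minimal sketch of the plain torch.cuda.amp training step without such a wrapper (function and variable names are illustrative):

import torch
from torch.cuda.amp import GradScaler, autocast


def amp_train_step(net, criterion, optimizer, scaler, inputs, labels):
    optimizer.zero_grad()
    with autocast():
        outputs = net(inputs)          # forward pass runs in mixed precision
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()      # scale the loss to avoid gradient underflow
    scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
    scaler.update()                    # adjust the scale factor for the next step
    return loss.item()
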
Example #5
def train(writer,
          labeled_loader,
          pseudo_labeled_loader,
          val_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          tensorboard_prefix,
          gamma1,
          gamma2,
          labeled_weight,
          start_at,
          num_classes,
          decay=0.999,
          alpha=-1,
          is_mixed_precision=False,
          loss_freq=10,
          val_num_steps=None,
          best_acc=0,
          fine_grain=False):
    # Define validation and loss value print frequency
    # Pseudo labeled defines epoch
    min_len = len(pseudo_labeled_loader)
    if min_len > loss_freq:
        loss_num_steps = int(min_len / loss_freq)
    else:  # For extremely small sets
        loss_num_steps = min_len
    if val_num_steps is None:
        val_num_steps = min_len

    if is_mixed_precision:
        scaler = GradScaler()

    net.train()

    # Use EMA to report final performance instead of selecting the best checkpoint with valtiny
    ema = EMA(net=net, decay=decay)

    epoch = 0

    # Training
    running_loss = 0.0
    running_stats = {
        'disagree': -1,
        'current_win': -1,
        'avg_weights': 1.0,
        'gamma1': 0,
        'gamma2': 0
    }
    iter_labeled = iter(labeled_loader)
    while epoch < num_epochs:
        train_correct = 0
        train_all = 0
        time_now = time.time()
        for i, data in enumerate(pseudo_labeled_loader, 0):
            # Pseudo labeled data
            inputs_pseudo, labels_pseudo = data
            inputs_pseudo, labels_pseudo = inputs_pseudo.to(
                device), labels_pseudo.to(device)

            # Hard labels
            probs_pseudo = labels_pseudo.clone().detach()
            labels_pseudo = labels_pseudo.argmax(-1)  # data type?

            # Labeled data
            inputs_labeled, labels_labeled = next(iter_labeled, (0, 0))
            if type(inputs_labeled) == type(labels_labeled) == int:
                iter_labeled = iter(labeled_loader)
                inputs_labeled, labels_labeled = next(iter_labeled, (0, 0))
            inputs_labeled, labels_labeled = inputs_labeled.to(
                device), labels_labeled.to(device)

            # To probabilities (in fact, just one-hot)
            probs_labeled = torch.nn.functional.one_hot(labels_labeled.clone().detach(), num_classes=num_classes) \
                .float()

            # Combine
            inputs = torch.cat([inputs_pseudo, inputs_labeled])
            labels = torch.cat([labels_pseudo, labels_labeled])
            probs = torch.cat([probs_pseudo, probs_labeled])
            optimizer.zero_grad()
            train_all += labels.shape[0]

            # mixup data within the batch
            if alpha != -1:
                dynamic_weights, stats = criterion.dynamic_weights_calc(
                    net=net,
                    inputs=inputs,
                    targets=probs,
                    split_index=inputs_pseudo.shape[0],
                    labeled_weight=labeled_weight)
                inputs, dynamic_weights, labels_a, labels_b, lam = mixup_data(
                    x=inputs,
                    w=dynamic_weights,
                    y=labels,
                    alpha=alpha,
                    keep_max=True)
            with autocast(is_mixed_precision):
                outputs = net(inputs)

            if alpha != -1:
                # Pseudo training accuracy & interesting loss
                predicted = outputs.argmax(1)
                train_correct += (
                    lam * (predicted == labels_a).sum().float().item() +
                    (1 - lam) * (predicted == labels_b).sum().float().item())
                loss, true_loss = criterion(pred=outputs,
                                            y_a=labels_a,
                                            y_b=labels_b,
                                            lam=lam,
                                            dynamic_weights=dynamic_weights)
            else:
                train_correct += (labels == outputs.argmax(1)).sum().item()
                loss, true_loss, stats = criterion(
                    inputs=outputs,
                    targets=probs,
                    split_index=inputs_pseudo.shape[0],
                    gamma1=gamma1,
                    gamma2=gamma2)

            if is_mixed_precision:
                accelerator.backward(scaler.scale(loss))
                scaler.step(optimizer)
                scaler.update()
            else:
                accelerator.backward(loss)
                optimizer.step()
            criterion.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            # EMA update
            ema.update(net=net)

            # Logging
            running_loss += true_loss
            for key in stats.keys():
                running_stats[key] += stats[key]
            current_step_num = int(epoch * len(pseudo_labeled_loader) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar(tensorboard_prefix + 'training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0
                for key in stats.keys():
                    print('[%d, %d] ' % (epoch + 1, i + 1) + key + ' : %.4f' %
                          (running_stats[key] / loss_num_steps))
                    writer.add_scalar(tensorboard_prefix + key,
                                      running_stats[key] / loss_num_steps,
                                      current_step_num)
                    running_stats[key] = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
               current_step_num == num_epochs * len(pseudo_labeled_loader) - 1:
                # Apex bug https://github.com/NVIDIA/apex/issues/706, fixed in PyTorch1.6, kept here for BC
                test_acc = test(loader=val_loader,
                                device=device,
                                net=net,
                                fine_grain=fine_grain,
                                is_mixed_precision=is_mixed_precision)
                writer.add_scalar(tensorboard_prefix + 'test accuracy',
                                  test_acc, current_step_num)
                net.train()

                # Record the best model (straight to disk)
                if test_acc >= best_acc:
                    best_acc = test_acc
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision)

        # Evaluate training accuracies (same metric as validation, but must be on-the-fly to save time)
        train_acc = train_correct / train_all * 100
        print('Train accuracy: %.4f' % train_acc)

        writer.add_scalar(tensorboard_prefix + 'train accuracy', train_acc,
                          epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    ema.fill_in_bn(state_dict=net.state_dict())
    save_checkpoint(net=ema,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=False,
                    filename='temp-ema.pt')
    return best_acc
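
The EMA helper used here (update after every optimizer step, fill_in_bn before saving) is defined elsewhere in the repository. A rough sketch of the usual parameter-averaging idea; the BatchNorm handling is only hinted at and the details are assumptions:

import copy

import torch


class EMA(torch.nn.Module):
    def __init__(self, net, decay=0.999):
        super().__init__()
        self.decay = decay
        self.shadow = copy.deepcopy(net).eval()   # averaged copy of the network
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, net):
        # shadow = decay * shadow + (1 - decay) * current parameters
        for ema_p, p in zip(self.shadow.parameters(), net.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    @torch.no_grad()
    def fill_in_bn(self, state_dict):
        # BatchNorm running statistics are buffers, not parameters, so they are
        # copied from the live model before the averaged weights are saved.
        own = self.shadow.state_dict()
        for key, value in state_dict.items():
            if 'running_mean' in key or 'running_var' in key:
                own[key].copy_(value)
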
Example #6
        if args.alpha == -1:
            criterion = DynamicMutualLoss()
        else:
            criterion = MixupDynamicMutualLoss(gamma1=args.gamma1,
                                               gamma2=args.gamma2,
                                               T_max=args.epochs *
                                               len(pseudo_labeled_loader))
        writer = SummaryWriter('logs/' + exp_name)

        best_acc = test(loader=val_loader,
                        device=device,
                        net=net,
                        fine_grain=args.fine_grain,
                        is_mixed_precision=args.mixed_precision)
        save_checkpoint(net=net,
                        optimizer=None,
                        lr_scheduler=None,
                        is_mixed_precision=args.mixed_precision)
        print('Original acc: ' + str(best_acc))

        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=args.epochs * len(pseudo_labeled_loader))
        # lr_scheduler = None

        # Retraining (i.e. fine-tuning)
        best_acc = train(writer=writer,
                         labeled_loader=labeled_loader,
                         pseudo_labeled_loader=pseudo_labeled_loader,
                         val_loader=val_loader,
                         device=device,
                         criterion=criterion,
Example #7
def run():
    model = net_factory.loader(opt.model, opt.num_classes)

    if opt.use_multi_gpu:
        model = nn.DataParallel(model)

    if opt.use_gpu:
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        model = model.cuda()
        criterion = opt.criterion.cuda()
    else:
        criterion = opt.criterion

    best_metric = 0
    best_loss = 10000

    if opt.try_resume:
        common.resume(model, opt.resumed_check, opt.model)

    transformed_dataset_train = nuclei_dataset.NucleiDataset(
        root_dir=opt.train_data_root,
        mode='train',
        transform=opt.transforms['train'])
    train_loader = data.DataLoader(transformed_dataset_train,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=opt.num_workers,
                                   pin_memory=opt.use_gpu)

    transformed_dataset_val = nuclei_dataset.NucleiDataset(
        root_dir=opt.val_data_root,
        mode='val',
        transform=opt.transforms['val'])
    val_loader = data.DataLoader(transformed_dataset_val,
                                 batch_size=opt.batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers,
                                 pin_memory=opt.use_gpu)
    print(len(train_loader) * opt.batch_size)
    print(len(val_loader) * opt.batch_size)

    if opt.optim_type == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)
    elif opt.optim_type == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.lr,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    lr_scheduler = lrs.ReduceLROnPlateau(optimizer,
                                         mode='min',
                                         factor=0.5,
                                         patience=2,
                                         verbose=True,
                                         threshold=0.0001,
                                         threshold_mode='rel',
                                         cooldown=0,
                                         min_lr=1e-6,
                                         eps=1e-08)

    for epoch in range(opt.epochs):

        # train for one epoch
        metric, loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        metric, loss = validate(val_loader, model, criterion, epoch)

        # remember best
        is_best = metric >= best_metric
        best_metric = max(metric, best_metric)
        if epoch % opt.save_freq == 0:
            common.save_checkpoint_epoch(
                {
                    'epoch': epoch,
                    'arch': opt.model,
                    'state_dict': model.state_dict(),
                    'best_metric': best_metric,
                    'loss': loss
                }, epoch, opt.model)
        if is_best:
            common.save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': opt.model,
                    'state_dict': model.state_dict(),
                    'best_metric': best_metric,
                    'loss': loss
                }, is_best, opt.model)

        lr_scheduler.step(loss)
Example #8
def train(writer,
          train_loader,
          val_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          log_file,
          alpha=None,
          is_mixed_precision=False,
          loss_freq=10,
          val_num_steps=None,
          best_acc=0,
          fine_grain=False,
          decay=0.999):
    # Define validation and loss value print frequency
    if len(train_loader) > loss_freq:
        loss_num_steps = int(len(train_loader) / loss_freq)
    else:  # For extremely small sets
        loss_num_steps = len(train_loader)
    if val_num_steps is None:
        val_num_steps = len(train_loader)

    net.train()

    # Use EMA to report final performance instead of selecting the best checkpoint with valtiny
    ema = EMA(net=net, decay=decay)

    epoch = 0

    # Training
    running_loss = 0.0
    while epoch < num_epochs:
        train_correct = 0
        train_all = 0
        time_now = time.time()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            train_all += labels.shape[0]

            # mixup data within the batch
            if alpha is not None:
                inputs, labels_a, labels_b, lam = mixup_data(x=inputs,
                                                             y=labels,
                                                             alpha=alpha)

            outputs = net(inputs)

            if alpha is not None:
                # Pseudo training accuracy & interesting loss
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b,
                                       lam)
                predicted = outputs.argmax(1)
                train_correct += (
                    lam * (predicted == labels_a).sum().float().item() +
                    (1 - lam) * (predicted == labels_b).sum().float().item())
            else:
                train_correct += (labels == outputs.argmax(1)).sum().item()
                loss = criterion(outputs, labels)

            if is_mixed_precision:
                # 2/3 & 3/3 of mixed precision training with amp
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            # EMA update
            ema.update(net=net)

            # Logging
            running_loss += loss.item()
            current_step_num = int(epoch * len(train_loader) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar('training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
               current_step_num == num_epochs * len(train_loader) - 1:
                # A bug in Apex? https://github.com/NVIDIA/apex/issues/706
                test_acc = test(loader=val_loader,
                                device=device,
                                net=net,
                                fine_grain=fine_grain)
                writer.add_scalar('test accuracy', test_acc, current_step_num)
                net.train()

                # Record the best model (straight to disk)
                if test_acc > best_acc:
                    best_acc = test_acc
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision,
                                    filename=log_file + '_temp.pt')

        # Evaluate training accuracies (same metric as validation, but must be on-the-fly to save time)
        train_acc = train_correct / train_all * 100
        print('Train accuracy: %.4f' % train_acc)

        writer.add_scalar('train accuracy', train_acc, epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    ema.fill_in_bn(state_dict=net.state_dict())
    save_checkpoint(net=ema,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=False,
                    filename=log_file + '_temp-ema.pt')
    return best_acc
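
mixup_data and mixup_criterion follow the standard mixup formulation: inputs and losses are convex combinations weighted by lam drawn from a Beta(alpha, alpha) distribution. A compact sketch of that formulation (the repository's versions may add options such as the per-sample weights and keep_max flag seen in Example #5):

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1.0 - lam) * x[index]   # convex combination of inputs
    return mixed_x, y, y[index], lam             # original and permuted labels


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss is mixed with the same coefficient as the inputs
    return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)
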
Example #9
        city_shape_match += 1
        my_city[my_key] = coco[key]
    else:
        city_shape_not_match += 1

print(str(voc_shape_match) + ' pascal voc shapes matched!')
print(str(voc_shape_not_match) + ' pascal voc shapes are not a match.')
print(str(city_shape_match) + ' cityscapes shapes matched!')
print(str(city_shape_not_match) + ' cityscapes shapes are not a match.')
print('Saving models...')

voc_net.load_state_dict(my_voc)
city_net.load_state_dict(my_city)
save_checkpoint(net=voc_net,
                optimizer=None,
                lr_scheduler=None,
                is_mixed_precision=False,
                filename='voc_coco_resnet101.pt')
save_checkpoint(net=city_net,
                optimizer=None,
                lr_scheduler=None,
                is_mixed_precision=False,
                filename='city_coco_resnet101.pt')
print('Complete.')

# Outputs should be the following after a few seconds:
# 528 pascal voc shapes matched!
# 0 pascal voc shapes are not a match.
# 520 cityscapes shapes matched!
# 8 cityscapes shapes are not a match.
# Saving models...
Example #10
    def test(self, is_teacher=False):
        torch.set_grad_enabled(False)
        epoch = self.epoch
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        if is_teacher:
            model = self.t_model
        else:
            model = self.s_model
        model.eval()
        timer_test = utility.timer()
        
        if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                i = 0
                for lr, hr, filename, _ in tqdm(d, ncols=80):
                    i += 1
                    lr, hr = self.prepare(lr, hr)
                    sr, s_res = model(lr)
                    sr = utility.quantize(sr, self.args.rgb_range)
                    save_list = [sr]
                    cur_psnr = utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    self.ckp.log[-1, idx_data, idx_scale] += cur_psnr
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        save_name = f'{args.k_bits}bit_{filename[0]}'
                        self.ckp.save_results(d, save_name, save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)

                self.ckp.write_log(
                    '[{} x{}] PSNR: {:.3f}  (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )
                self.writer_train.add_scalar('psnr', self.ckp.log[-1, idx_data, idx_scale], self.epoch)

        if self.args.save_results:
            self.ckp.end_background()
            
        if not self.args.test_only:
            is_best = (best[1][0, 0] + 1 == epoch)

            state = {
                'epoch': epoch,
                'state_dict': self.s_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.sheduler.state_dict()
            }
            util.save_checkpoint(state, is_best, checkpoint=self.ckp.dir + '/model')
        
        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)
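
utility.calc_psnr above accumulates a PSNR value per image, where PSNR = 10 * log10(MAX^2 / MSE) for pixel range MAX. A bare-bones sketch of that computation; the real helper also shaves borders according to the scale and may convert to the Y channel, both omitted here:

import torch


def calc_psnr(sr, hr, rgb_range=255.0):
    # Peak signal-to-noise ratio between a super-resolved and a ground-truth image
    mse = torch.mean((sr - hr) ** 2)
    if mse == 0:
        return float('inf')
    return (10.0 * torch.log10(rgb_range ** 2 / mse)).item()
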
Example #11
def main(args):
    check_params(args)

    print("Init models...")

    G = Generator(args.dataset).cuda()
    D = Discriminator(args).cuda()

    loss_tracker = LossSummary()

    loss_fn = AnimeGanLoss(args)

    # Create DataLoader
    data_loader = DataLoader(
        AnimeDataSet(args),
        batch_size=args.batch_size,
        num_workers=cpu_count(),
        pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
    )

    optimizer_g = optim.Adam(G.parameters(), lr=args.lr_g, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(D.parameters(), lr=args.lr_d, betas=(0.5, 0.999))

    start_e = 0
    if args.resume == 'GD':
        # Load G and D
        try:
            start_e = load_checkpoint(G, args.checkpoint_dir)
            print("G weight loaded")
            load_checkpoint(D, args.checkpoint_dir)
            print("D weight loaded")
        except Exception as e:
            print('Could not load checkpoint, train from scratch', e)
    elif args.resume == 'G':
        # Load G only
        try:
            start_e = load_checkpoint(G, args.checkpoint_dir, posfix='_init')
        except Exception as e:
            print('Could not load G init checkpoint, train from scratch', e)

    for e in range(start_e, args.epochs):
        print(f"Epoch {e}/{args.epochs}")
        bar = tqdm(data_loader)
        G.train()

        init_losses = []

        if e < args.init_epochs:
            # Train with content loss only
            set_lr(optimizer_g, args.init_lr)
            for img, *_ in bar:
                img = img.cuda()

                optimizer_g.zero_grad()

                fake_img = G(img)
                loss = loss_fn.content_loss_vgg(img, fake_img)
                loss.backward()
                optimizer_g.step()

                init_losses.append(loss.cpu().detach().numpy())
                avg_content_loss = sum(init_losses) / len(init_losses)
                bar.set_description(
                    f'[Init Training G] content loss: {avg_content_loss:.2f}')

            set_lr(optimizer_g, args.lr_g)
            save_checkpoint(G, optimizer_g, e, args, posfix='_init')
            save_samples(G, data_loader, args, subname='initg')
            continue

        loss_tracker.reset()
        for img, anime, anime_gray, anime_smt_gray in bar:
            # To cuda
            img = img.cuda()
            anime = anime.cuda()
            anime_gray = anime_gray.cuda()
            anime_smt_gray = anime_smt_gray.cuda()

            # ---------------- TRAIN D ---------------- #
            optimizer_d.zero_grad()
            fake_img = G(img).detach()

            # Add some Gaussian noise to images before feeding to D
            if args.d_noise:
                fake_img += gaussian_noise()
                anime += gaussian_noise()
                anime_gray += gaussian_noise()
                anime_smt_gray += gaussian_noise()

            fake_d = D(fake_img)
            real_anime_d = D(anime)
            real_anime_gray_d = D(anime_gray)
            real_anime_smt_gray_d = D(anime_smt_gray)

            loss_d = loss_fn.compute_loss_D(fake_d, real_anime_d,
                                            real_anime_gray_d,
                                            real_anime_smt_gray_d)

            loss_d.backward()
            optimizer_d.step()

            loss_tracker.update_loss_D(loss_d)

            # ---------------- TRAIN G ---------------- #
            optimizer_g.zero_grad()

            fake_img = G(img)
            fake_d = D(fake_img)

            adv_loss, con_loss, gra_loss, col_loss = loss_fn.compute_loss_G(
                fake_img, img, fake_d, anime_gray)

            loss_g = adv_loss + con_loss + gra_loss + col_loss

            loss_g.backward()
            optimizer_g.step()

            loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)

            avg_adv, avg_gram, avg_color, avg_content = loss_tracker.avg_loss_G()
            avg_adv_d = loss_tracker.avg_loss_D()
            bar.set_description(
                f'loss G: adv {avg_adv:.2f} con {avg_content:.2f} gram {avg_gram:.2f} '
                f'color {avg_color:.2f} / loss D: {avg_adv_d:.2f}')

        if e % args.save_interval == 0:
            save_checkpoint(G, optimizer_g, e, args)
            save_checkpoint(D, optimizer_d, e, args)
            save_samples(G, data_loader, args)
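
set_lr and gaussian_noise are small helpers that are not shown in this example. Plausible sketches of both; the noise amplitude, shape, and device handling are assumptions (the real gaussian_noise presumably returns noise matched to the image tensors):

import torch


def set_lr(optimizer, lr):
    # Overwrite the learning rate of every parameter group in place
    for group in optimizer.param_groups:
        group['lr'] = lr


def gaussian_noise(sigma=0.02):
    # Zero-mean scalar noise; a 0-dim tensor broadcasts over an image batch
    return torch.randn(()) * sigma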