Example #1
def train_validate_with_scheduling(args,
                                   net,
                                   criterion,
                                   optimizer,
                                   compress_scheduler,
                                   device,
                                   epoch=1,
                                   validate=True,
                                   verbose=True):
    # What's collectors_context?
    if compress_scheduler:
        compress_scheduler.on_epoch_begin(epoch)

    top1, top5, loss = light_train_with_distiller(net, criterion, optimizer,
                                                  compress_scheduler, device,
                                                  epoch)

    if validate:
        """
        top1, loss = _validate(net, criterion, optimizer, lr_scheduler, compress_scheduler,
                                device, epoch) # remove top5 accuracy.
        """
        top1, top5, loss = _validate('val', net, criterion, device)
    #print(summary.masks_sparsity_tbl_summary(net, compress_scheduler))
    t, total = summary.weights_sparsity_tbl_summary(net,
                                                    return_total_sparsity=True)
    print("\nParameters:\n" + str(t))
    print('Total sparsity: {:0.2f}\n'.format(total))

    if compress_scheduler:
        compress_scheduler.on_epoch_end(epoch,
                                        optimizer,
                                        metrics={
                                            'min': loss,
                                            'max': top1
                                        })

    # Track performance and record whether this epoch gives the best score so far.
    tracker = pt.SparsityAccuracyTracker(args.num_best_scores)
    tracker.step(net, epoch, top1=top1, top5=top5)
    best_score = tracker.best_scores()[0]
    is_best = epoch == best_score.epoch
    checkpoint_extras = {
        'current_top1': top1,
        'best_top1': best_score.top1,
        'best_epoch': best_score.epoch
    }

    # args.arch = Architecture name
    ckpt.save_checkpoint(epoch,
                         args.arch,
                         net,
                         optimizer=optimizer,
                         scheduler=compress_scheduler,
                         extras=checkpoint_extras,
                         is_best=is_best,
                         name=args.name,
                         dir=args.model_path)
    return top1, top5, loss, tracker
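A hedged sketch of how the function above might be driven from an outer loop; every name it uses (args, net, criterion, optimizer, compress_scheduler, device) is assumed to already exist exactly as in the snippet, so this is illustrative rather than runnable on its own.

# Sketch only: run compression-aware training for the configured number of epochs.
for epoch in range(args.epochs):
    top1, top5, loss, tracker = train_validate_with_scheduling(
        args, net, criterion, optimizer, compress_scheduler, device,
        epoch=epoch, validate=True)
    print('epoch {}: top1={:.2f}  top5={:.2f}  loss={:.4f}'.format(
        epoch, top1, top5, loss))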
Example #2
def quantize_and_test_model(test_loader,
                            model,
                            criterion,
                            args,
                            scheduler=None,
                            save_flag=True):
    """Collect stats using test_loader (when stats file is absent),
    clone the model and quantize the clone, and finally, test it.
    args.device is allowed to differ from the model's device.
    When args.qe_calibration is set to None, uses 0.05 instead.
    scheduler - pass scheduler to store it in checkpoint
    save_flag - defaults to save both quantization statistics and checkpoint.
    """
    if hasattr(model, 'quantizer_metadata') and \
            model.quantizer_metadata['type'] == distiller.quantization.PostTrainLinearQuantizer:
        raise RuntimeError(
            'Trying to invoke post-training quantization on a model that has already been post-'
            'train quantized. Model was likely loaded from a checkpoint. Please run again without '
            'passing the --quantize-eval flag')
    if not (args.qe_dynamic or args.qe_stats_file or args.qe_config_file):

        args_copy = copy.deepcopy(args)
        args_copy.qe_calibration = args.qe_calibration if args.qe_calibration is not None else 0.05

        # Store the collected activation stats in args.qe_stats_file
        args.qe_stats_file = acts_quant_stats_collection(
            model, criterion, loggers, args_copy, save_to_file=save_flag)

    args_qe = copy.deepcopy(args)
    if args.device == 'cpu':
        # NOTE: Even though args.device is 'cpu', the model itself may still be on another device.
        qe_model = distiller.make_non_parallel_copy(model).cpu()
    else:
        qe_model = copy.deepcopy(model).to(args.device)

    quantizer = quantization.PostTrainLinearQuantizer.from_args(
        qe_model, args_qe)
    dummy_input = utl.get_dummy_input(
        input_shape=(1, 3, 224, 224))  # FIXME: hard-coded input shape; should come from args
    quantizer.prepare_model(dummy_input)

    if args.qe_convert_pytorch:
        qe_model = _convert_ptq_to_pytorch(qe_model, args_qe)
    # should check device
    test_res = test(qe_model, criterion, args.device)

    if save_flag:
        checkpoint_name = 'quantized'
        ckpt.save_checkpoint(0,
                             args_qe.arch,
                             qe_model,
                             scheduler=scheduler,
                             name='_'.join([args_qe.name, checkpoint_name])
                             if args_qe.name else checkpoint_name,
                             dir=args.model_path,
                             extras={'quantized_top1': test_res[0]})

    del qe_model
    return test_res
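A hedged sketch of how this helper might be invoked once training has finished; test_loader, model, criterion, and args are the same objects the snippet assumes, and per the code above test_res[0] holds the quantized top-1 accuracy.

# Sketch only: quantize the trained model and evaluate the quantized copy.
test_res = quantize_and_test_model(test_loader, model, criterion, args,
                                   scheduler=None, save_flag=True)
print('quantized top-1 accuracy: {:.2f}'.format(test_res[0]))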
Example #3
def train(config, dataset, model):
    # Data loaders
    train_loader, val_loader = dataset.train_loader, dataset.val_loader

    if 'use_weighted' not in config:
        # TODO (part c): define loss function
        criterion = None
    else:
        # TODO (part e): define weighted loss function
        criterion = None
    # TODO (part c): define optimizer
    learning_rate = config['learning_rate']
    optimizer = None

    # Attempts to restore the latest checkpoint if exists
    print('Loading model...')
    force = config['ckpt_force'] if 'ckpt_force' in config else False
    model, start_epoch, stats = checkpoint.restore_checkpoint(
        model, config['ckpt_path'], force=force)

    # Create plotter
    plot_name = config['plot_name'] if 'plot_name' in config else 'CNN'
    plotter = Plotter(stats, plot_name)

    # Evaluate the model
    _evaluate_epoch(plotter, train_loader, val_loader, model, criterion,
                    start_epoch)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config['num_epoch']):
        # Train model on training set
        _train_epoch(train_loader, model, criterion, optimizer)

        # Evaluate model on training and validation set
        _evaluate_epoch(plotter, train_loader, val_loader, model, criterion,
                        epoch + 1)

        # Save model parameters
        checkpoint.save_checkpoint(model, epoch + 1, config['ckpt_path'],
                                   plotter.stats)

    print('Finished Training')

    # Save figure and keep plot open
    plotter.save_cnn_training_plot()
    plotter.hold_training_plot()
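One plausible way to fill in the criterion/optimizer TODOs above, as a self-contained sketch; the weighted variant and the choice of Adam are assumptions, not the assignment's required answer.

import torch
import torch.nn as nn

# Sketch only: unweighted vs. class-weighted loss, plus an Adam optimizer.
def make_criterion_and_optimizer(model, learning_rate, class_weights=None):
    if class_weights is None:
        criterion = nn.CrossEntropyLoss()
    else:
        # Weighted loss for imbalanced classes
        criterion = nn.CrossEntropyLoss(
            weight=torch.tensor(class_weights, dtype=torch.float))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    return criterion, optimizer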
Example #4
def train(checkpoint_path):
    # Whether to load model parameters
    load = False

    if load:
        checkpoint = cp.load_checkpoint(address=checkpoint_path)
        net.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0

    for epoch in range(start_epoch, n_epoch):
        train_one_epoch()

        # Save parameters
        checkpoint = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        cp.save_checkpoint(checkpoint, address=checkpoint_path)

        eval()
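A minimal, self-contained sketch of the cp.save_checkpoint / cp.load_checkpoint helpers this snippet relies on, implemented here with plain torch.save / torch.load (an assumption about how cp works, not its actual code).

import torch

def save_checkpoint(checkpoint, address):
    # Persist the whole checkpoint dict (epoch, state_dict, optimizer) to disk.
    torch.save(checkpoint, address)

def load_checkpoint(address, device='cpu'):
    # Load the checkpoint dict back, mapping tensors onto the requested device.
    return torch.load(address, map_location=device)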
Example #5
    def train(self, checkpoint_path):
        # Whether to load model parameters
        load = False

        if load:
            checkpoint = cp.load_checkpoint(address=checkpoint_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = 0

        for epoch in range(start_epoch, self.n_epoch):
            self.train_one_epoch(epoch)

            # Save parameters
            checkpoint = {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }
            cp.save_checkpoint(checkpoint, address=checkpoint_path)

            if self.selftest:
                self.eval(epoch)
def train(config,
          dataset,
          model,
          save_weights_iter=None,
          params_to_save=None,
          save_progress=True):

    lr = config['lr']
    beta1 = config['beta1']
    device = config['device']
    val_freq = config['val_freq']
    batch_size = config['batch_size']
    opt = config['opt']
    scheduler = config['scheduler']

    crit = config['crit']

    train_loader, val_loader, test_loader = dataset.train_loader, dataset.val_loader, dataset.test_loader

    force = config['ckpt_force'] if 'ckpt_force' in config else False

    if save_progress:
        model, opt, scheduler, start_epoch, stats = checkpoint.restore_checkpoint(
            model,
            opt,
            scheduler,
            config['ckpt_path'],
            force=force,
            pretrain=False)
    else:
        start_epoch = 0
        stats = None

    if not stats:
        model_loss = []
        model_acc = []
        iterations = []
    else:
        model_loss, model_acc, iterations = stats
    itr = 0

    train_loss = []

    for epoch in range(start_epoch, config['num_epoch']):
        model.train()
        epoch_train_loss = []
        for i, (X, y) in enumerate(train_loader):
            if save_weights_iter is not None and itr == save_weights_iter:
                checkpoint_weights(model, params_to_save)
            X = X.to(device)
            y = y.to(device)
            opt.zero_grad()
            output = model(X).to(device)
            loss = crit(output, y)
            loss.backward()
            opt.step()

            train_loss.append(loss.item())
            epoch_train_loss.append(loss.item())
            acc, loss = 0, 0
            if itr % val_freq == 0:
                acc, loss = test(model, val_loader, crit, device)
                model_acc.append(acc)
                model_loss.append(loss)
                iterations.append(itr)

                print(itr, epoch, acc, loss, avg(epoch_train_loss))

            itr += 1

        stats = [model_loss, model_acc, iterations]
        if save_progress:
            checkpoint.save_checkpoint(model, opt, scheduler, epoch + 1,
                                       config['ckpt_path'], stats)
        if scheduler:
            scheduler.step()

    test_acc, test_loss = test(model, test_loader, crit, device)
    return [
        model, model_loss, model_acc, iterations, test_loss, test_acc,
        train_loss
    ]
Example #7
def train(**args):
    """
    Train the selected model
    Args:
        rerun        (Int):        Integer indicating number of repetitions for the selected experiment
        seed         (Int):        Integer indicating set seed for random state
        save_dir     (String):     Top level directory to generate results folder
        model        (String):     Name of selected model 
        dataset      (String):     Name of selected dataset  
        exp          (String):     Name of experiment 
        debug        (Int):        Debug state to avoid saving variables 
        load_type    (String):     Keyword indicator to evaluate the testing or validation set
        pretrained   (Int/String): Int/String indicating loading of random, pretrained or saved weights
        opt          (String):     Name of the optimizer to use ('sgd' or 'adam')
        lr           (Float):      Learning rate 
        momentum     (Float):      Momentum in optimizer 
        weight_decay (Float):      Weight_decay value 
        final_shape  ([Int, Int]): Shape of data when passed into network
        
    Return:
        None
    """

    print(
        "\n############################################################################\n"
    )
    print("Experimental Setup: ", args)
    print(
        "\n############################################################################\n"
    )

    for total_iteration in range(args['rerun']):

        # Generate Results Directory
        d = datetime.datetime.today()
        date = d.strftime('%Y%m%d-%H%M%S')
        result_dir = os.path.join(
            args['save_dir'], args['model'], '_'.join(
                (args['dataset'], args['exp'], date)))
        log_dir = os.path.join(result_dir, 'logs')
        save_dir = os.path.join(result_dir, 'checkpoints')

        if not args['debug']:
            os.makedirs(result_dir, exist_ok=True)
            os.makedirs(log_dir, exist_ok=True)
            os.makedirs(save_dir, exist_ok=True)

            # Save copy of config file
            with open(os.path.join(result_dir, 'config.yaml'), 'w') as outfile:
                yaml.dump(args, outfile, default_flow_style=False)

            # Tensorboard Element
            writer = SummaryWriter(log_dir)

        # Check if GPU is available (CUDA)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Load Network
        model = create_model_object(**args).to(device)

        # Load Data
        loader = data_loader(model_obj=model, **args)

        if args['load_type'] == 'train':
            train_loader = loader['train']
            valid_loader = loader[
                'train']  # Run accuracy on train data if only `train` selected

        elif args['load_type'] == 'train_val':
            train_loader = loader['train']
            valid_loader = loader['valid']

        else:
            sys.exit('Invalid environment selection for training, exiting')

        # END IF

        # Training Setup
        params = [p for p in model.parameters() if p.requires_grad]

        if args['opt'] == 'sgd':
            optimizer = optim.SGD(params,
                                  lr=args['lr'],
                                  momentum=args['momentum'],
                                  weight_decay=args['weight_decay'])

        elif args['opt'] == 'adam':
            optimizer = optim.Adam(params,
                                   lr=args['lr'],
                                   weight_decay=args['weight_decay'])

        else:
            sys.exit('Unsupported optimizer selected. Exiting')

        # END IF

        scheduler = MultiStepLR(optimizer,
                                milestones=args['milestones'],
                                gamma=args['gamma'])

        if isinstance(args['pretrained'], str):
            ckpt = load_checkpoint(args['pretrained'])
            model.load_state_dict(ckpt)
            start_epoch = load_checkpoint(args['pretrained'],
                                          key_name='epoch') + 1
            optimizer.load_state_dict(
                load_checkpoint(args['pretrained'], key_name='optimizer'))

            for quick_looper in range(start_epoch):
                scheduler.step()

            # END FOR

        else:
            start_epoch = 0

        # END IF

        model_loss = Losses(device=device, **args)
        acc_metric = Metrics(**args)
        best_val_acc = 0.0

        ############################################################################################################################################################################

        # Start: Training Loop
        for epoch in range(start_epoch, args['epoch']):
            running_loss = 0.0
            print('Epoch: ', epoch)

            # Setup Model To Train
            model.train()

            # Start: Epoch
            for step, data in enumerate(train_loader):
                if step % args['pseudo_batch_loop'] == 0:
                    loss = 0.0
                    optimizer.zero_grad()

                # END IF

                x_input = data['data'].to(device)
                annotations = data['annots']

                assert args['final_shape'] == list(x_input.size(
                )[-2:]), "Input to model does not match final_shape argument"
                outputs = model(x_input)
                loss = model_loss.loss(outputs, annotations)
                loss = loss * args['batch_size']
                loss.backward()

                running_loss += loss.item()

                if np.isnan(running_loss):
                    import pdb
                    pdb.set_trace()

                # END IF

                if not args['debug']:
                    # Add Learning Rate Element
                    for param_group in optimizer.param_groups:
                        writer.add_scalar(
                            args['dataset'] + '/' + args['model'] +
                            '/learning_rate', param_group['lr'],
                            epoch * len(train_loader) + step)

                    # END FOR

                    # Add Loss Element
                    writer.add_scalar(
                        args['dataset'] + '/' + args['model'] +
                        '/minibatch_loss',
                        loss.item() / args['batch_size'],
                        epoch * len(train_loader) + step)

                # END IF

                if ((epoch * len(train_loader) + step + 1) % 100 == 0):
                    print('Epoch: {}/{}, step: {}/{} | train loss: {:.4f}'.
                          format(
                              epoch, args['epoch'], step + 1,
                              len(train_loader), running_loss /
                              float(step + 1) / args['batch_size']))

                # END IF

                if (epoch * len(train_loader) +
                    (step + 1)) % args['pseudo_batch_loop'] == 0 and step > 0:
                    # Apply large mini-batch normalization
                    for param in model.parameters():
                        param.grad *= 1. / float(
                            args['pseudo_batch_loop'] * args['batch_size'])
                    optimizer.step()

                # END IF

            # END FOR: Epoch

            if not args['debug']:
                # Save Current Model
                save_path = os.path.join(
                    save_dir, args['dataset'] + '_epoch' + str(epoch) + '.pkl')
                save_checkpoint(epoch, step, model, optimizer, save_path)

            # END IF: Debug

            scheduler.step(epoch=epoch)
            print('Scheduler lr: %f' % scheduler.get_lr()[0])

            ## START FOR: Validation Accuracy
            running_acc = []
            running_acc = valid(valid_loader, running_acc, model, device,
                                acc_metric)
            if not args['debug']:
                writer.add_scalar(
                    args['dataset'] + '/' + args['model'] +
                    '/validation_accuracy', 100. * running_acc[-1],
                    epoch * len(valid_loader) + step)
            print('Accuracy of the network on the validation set: %f %%\n' %
                  (100. * running_acc[-1]))

            # Save Best Validation Accuracy Model Separately
            if best_val_acc < running_acc[-1]:
                best_val_acc = running_acc[-1]

                if not args['debug']:
                    # Save Current Model
                    save_path = os.path.join(
                        save_dir, args['dataset'] + '_best_model.pkl')
                    save_checkpoint(epoch, step, model, optimizer, save_path)

                # END IF

            # END IF

        # END FOR: Training Loop

    ############################################################################################################################################################################

        if not args['debug']:
            # Close Tensorboard Element
            writer.close()
    else:
        architecture = 'vgg16'

    if (args.checkpoint_name):
        checkpoint_name = args.checkpoint_name
    else:
        checkpoint_name = 'ic-model.pth'

    if (args.root_dir):
        root_dir = args.root_dir
    else:
        root_dir = '/'

    model = train_model(image_datasets=image_datasets,
                        dataloaders=dataloaders,
                        dataset_sizes=dataset_sizes,
                        arch=architecture,
                        hidden_units=hidden_units,
                        num_epochs=eps,
                        learning_rate=learning_rate,
                        device=device)

    print(model)
    class_to_idx = image_datasets['train'].class_to_idx

    loader.save_checkpoint(model=model,
                           checkpoint_name=checkpoint_name,
                           arch=architecture,
                           hidden_units=hidden_units,
                           class_to_idx=class_to_idx,
                           learning_rate=learning_rate)
    def train_(self, criterion, logger, model, **pars):
        """
        grad_cache will be updated in-place
        *** only learning_rate and current_loader_ind need to be loaded at checkpoint
        K: the number of active clients
        """
        default_pars = dict(learning_rate=1e-2,
                            K=10,
                            num_its=5000,
                            lr_decay=0.5,
                            decay_step_size=1000,
                            print_every=50,
                            checkpoint_interval=1000)
        init_pars(default_pars, pars)
        pars = default_pars

        K = pars['K']
        learning_rate = pars['learning_rate']
        num_its = pars['num_its']
        lr_decay = pars['lr_decay']
        decay_step_size = pars['decay_step_size']
        print_every = pars['print_every']
        checkpoint_interval = pars['checkpoint_interval']

        I = self.pars['I']
        N = self.pars['N']

        checkpoint_dir = self.checkpoint_dir

        logger.add_meta_data(pars, 'training')
        logger.add_meta_data(self.pars, 'simulation')

        if use_cuda:
            model = model.to(torch.device('cuda'))
        else:
            model = model.to(torch.device('cpu'))

        if osp.exists(osp.join(checkpoint_dir, 'meta.pkl')):
            current_it = load_checkpoint(checkpoint_dir, model, logger)
        else:
            current_it = 0

        while True:
            current_lr = learning_rate * (lr_decay
                                          **(current_it // decay_step_size))
            print(f"current_it={current_it}, current_lr={current_lr}",
                  end='\r')

            global_model = deepcopy(model)
            zero_model(global_model)

            # randomly choose the K active clients for this round
            idxs_users = np.random.choice(range(N), K, replace=False)
            for idx in idxs_users:
                worker = self.workers[idx]
                local_model = deepcopy(model)
                worker.train_(local_model,
                              criterion,
                              current_lr=current_lr,
                              num_its=I)
                aggregate_model(
                    global_model, local_model, 1,
                    N / K * (worker.num_train / self.num_total_samples))
            model = global_model
            logger.add_train_loss(
                list(model.parameters())[0][0][0][0][0], current_it,
                'model-par')

            if current_it % print_every == 0:
                # fedavg
                fed_acc_array = self.test_model(model)
                fed_acc = np.array(fed_acc_array).mean()
                print('%d fedavg test acc: %.3f%%' %
                      (current_it, fed_acc * 100.0))
                logger.add_test_acc(fed_acc, current_it, 'fedavg')

            if current_it % checkpoint_interval == 0:
                save_checkpoint(current_it, model, logger, checkpoint_dir)

            if current_it == num_its:
                print('Finished Training')
                return

            current_it += 1
        print('*' * 30)
    else:
        start_epoch = 0
        # model = init_weights(model)

    for epoch in range(start_epoch, n_epochs):
        print('Epoch: %d/%d' % (epoch + 1, n_epochs))
        train_loss = train_model(model,
                                 train_loader,
                                 criterion,
                                 optimizer,
                                 device,
                                 measure_accuracy=True,
                                 opti_batch=args.opti_batch)
        val_loss, val_acc = val_model(model, test_loader, criterion, device)
        # Checkpoint the model after each epoch.
        save_checkpoint(epoch + 1,
                        args.model_path,
                        model=model,
                        optimizer=optimizer,
                        val_metric=val_acc)
        if args.fixed_lr_decay:
            if epoch in args.lr_decay_epochs:
                cur_lr = get_current_lr(optimizer)
                optimizer = set_current_lr(optimizer, cur_lr * 0.1)
                print('Epoch    %d: reducing learning rate of group 0 to %f' %
                      (epoch, cur_lr * 0.1))
        else:
            scheduler.step(val_loss)
        print('=' * 20)
Example #11
def train(
    logger: lavd.Logger,
    model: nn.Module,
    optimiser: optim.Optimizer,  # type: ignore
    train_data_loader: DataLoader,
    validation_data_loaders: List[DataLoader],
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    device: torch.device,
    checkpoint: Dict,
    num_epochs: int = num_epochs,
    model_kind: str = default_model,
    amp_scaler: Optional[amp.GradScaler] = None,
    masked_lm: bool = True,
):
    start_epoch = checkpoint["epoch"]
    train_stats = checkpoint["train"]
    validation_cp = checkpoint["validation"]
    outdated_validations = checkpoint["outdated_validation"]

    validation_results_dict: Dict[str, Dict] = OrderedDict()
    for val_data_loader in validation_data_loaders:
        val_name = val_data_loader.dataset.name
        val_result = (validation_cp[val_name] if val_name in validation_cp else
                      OrderedDict(start=start_epoch,
                                  stats=OrderedDict(loss=[], perplexity=[])))
        validation_results_dict[val_name] = val_result

    # All validations that are no longer used will be stored in outdated_validation
    # just to have them available.
    outdated_validations.append(
        OrderedDict({
            k: v
            for k, v in validation_cp.items()
            if k not in validation_results_dict
        }))

    tokeniser = train_data_loader.dataset.tokeniser  # type: ignore
    for epoch in range(num_epochs):
        actual_epoch = start_epoch + epoch + 1
        epoch_text = "[{current:>{pad}}/{end}] Epoch {epoch}".format(
            current=epoch + 1,
            end=num_epochs,
            epoch=actual_epoch,
            pad=len(str(num_epochs)),
        )
        logger.set_prefix(epoch_text)
        logger.start(epoch_text, prefix=False)
        start_time = time.time()

        logger.start("Train")
        train_result = run_epoch(
            train_data_loader,
            model,
            optimiser,
            device=device,
            epoch=epoch,
            train=True,
            name="Train",
            logger=logger,
            amp_scaler=amp_scaler,
            masked_lm=masked_lm,
        )
        train_stats["stats"]["loss"].append(train_result["loss"])
        train_stats["stats"]["perplexity"].append(train_result["perplexity"])
        epoch_lr = lr_scheduler.get_last_lr()[0]  # type: ignore
        train_stats["lr"].append(epoch_lr)
        lr_scheduler.step()
        logger.end("Train")

        validation_results = []
        for val_data_loader in validation_data_loaders:
            val_name = val_data_loader.dataset.name
            val_text = "Validation: {}".format(val_name)
            logger.start(val_text)
            validation_result = run_epoch(
                val_data_loader,
                model,
                optimiser,
                device=device,
                epoch=epoch,
                train=False,
                name=val_text,
                logger=logger,
                amp_scaler=amp_scaler,
                masked_lm=masked_lm,
            )
            validation_results.append(
                OrderedDict(name=val_name, stats=validation_result))
            validation_results_dict[val_name]["stats"]["loss"].append(
                validation_result["loss"])
            validation_results_dict[val_name]["stats"]["perplexity"].append(
                validation_result["perplexity"])
            logger.end(val_text)

        with logger.spinner("Checkpoint", placement="right"):
            # Multi-gpu models wrap the original model. To make the checkpoint
            # compatible with the original model, the state dict of .module is saved.
            model_unwrapped = (model.module if isinstance(
                model, DistributedDataParallel) else model)
            save_checkpoint(
                logger,
                model_unwrapped,
                tokeniser,
                stats=OrderedDict(
                    epoch=actual_epoch,
                    train=train_stats,
                    validation=validation_results_dict,
                    outdated_validation=outdated_validations,
                    model=OrderedDict(kind=model_kind),
                ),
                step=actual_epoch,
            )

        with logger.spinner("Logging Data", placement="right"):
            log_results(
                logger,
                actual_epoch,
                OrderedDict(lr=epoch_lr, stats=train_result),
                validation_results,
                model_unwrapped,
            )

        with logger.spinner("Best Checkpoints", placement="right"):
            val_stats = OrderedDict({
                val_name: {
                    "name": val_name,
                    "start": val_result["start"],
                    "stats": val_result["stats"],
                }
                for val_name, val_result in validation_results_dict.items()
            })
            log_top_checkpoints(logger, val_stats, metrics)

        time_difference = time.time() - start_time
        epoch_results = [OrderedDict(name="Train", stats=train_result)
                         ] + validation_results
        log_epoch_stats(logger,
                        epoch_results,
                        metrics,
                        lr=epoch_lr,
                        time_elapsed=time_difference)
        logger.end(epoch_text, prefix=False)
def train(
    enc,
    dec,
    optimiser,
    criterion,
    data_loader,
    device,
    teacher_forcing_ratio=teacher_forcing_ratio,
    lr_scheduler=None,
    num_epochs=100,
    print_epochs=None,
    checkpoint=default_checkpoint,
    prefix="",
    max_grad_norm=max_grad_norm,
):
    if print_epochs is None:
        print_epochs = num_epochs

    writer = init_tensorboard(name=prefix.strip("-"))
    start_epoch = checkpoint["epoch"]
    accuracy = checkpoint["accuracy"]
    losses = checkpoint["losses"]
    learning_rates = checkpoint["lr"]
    grad_norms = checkpoint["grad_norm"]
    optim_params = [
        p for param_group in optimiser.param_groups
        for p in param_group["params"]
    ]

    for epoch in range(num_epochs):
        start_time = time.time()
        epoch_losses = []
        epoch_grad_norms = []
        epoch_correct_symbols = 0
        total_symbols = 0

        if lr_scheduler:
            lr_scheduler.step()

        epoch_text = "[{current:>{pad}}/{end}] Epoch {epoch}".format(
            current=epoch + 1,
            end=num_epochs,
            epoch=start_epoch + epoch + 1,
            pad=len(str(num_epochs)),
        )

        with tqdm(
                desc=epoch_text,
                total=len(data_loader.dataset),
                dynamic_ncols=True,
                leave=False,
        ) as pbar:
            for d in data_loader:
                input = d["image"].to(device)
                # The last batch may not be a full batch
                curr_batch_size = len(input)
                expected = d["truth"]["encoded"].to(device)
                batch_max_len = expected.size(1)
                # Replace -1 with the PAD token
                expected[expected == -1] = data_loader.dataset.token_to_id[PAD]
                enc_low_res, enc_high_res = enc(input)
                # Decoder needs to be reset, because the coverage attention (alpha)
                # only applies to the current image.
                dec.reset(curr_batch_size)
                hidden = dec.init_hidden(curr_batch_size).to(device)
                # Starts with a START token
                sequence = torch.full(
                    (curr_batch_size, 1),
                    data_loader.dataset.token_to_id[START],
                    dtype=torch.long,
                    device=device,
                )
                # The teacher forcing is done per batch, not symbol
                use_teacher_forcing = random.random() < teacher_forcing_ratio
                decoded_values = []
                for i in range(batch_max_len - 1):
                    previous = (expected[:, i]
                                if use_teacher_forcing else sequence[:, -1])
                    previous = previous.view(-1, 1)
                    out, hidden = dec(previous, hidden, enc_low_res,
                                      enc_high_res)
                    hidden = hidden.detach()
                    _, top1_id = torch.topk(out, 1)
                    sequence = torch.cat((sequence, top1_id), dim=1)
                    decoded_values.append(out)

                decoded_values = torch.stack(decoded_values, dim=2).to(device)
                optimiser.zero_grad()
                # decoded_values does not contain the start symbol
                loss = criterion(decoded_values, expected[:, 1:])
                loss.backward()
                # Clip gradients, it returns the total norm of all parameters
                grad_norm = nn.utils.clip_grad_norm_(optim_params,
                                                     max_norm=max_grad_norm)
                optimiser.step()

                epoch_losses.append(loss.item())
                epoch_grad_norms.append(grad_norm)
                epoch_correct_symbols += torch.sum(sequence == expected,
                                                   dim=(0, 1)).item()
                total_symbols += expected.numel()
                pbar.update(curr_batch_size)

        mean_epoch_loss = np.mean(epoch_losses)
        mean_epoch_grad_norm = np.mean(epoch_grad_norms)
        losses.append(mean_epoch_loss)
        grad_norms.append(mean_epoch_grad_norm)
        epoch_accuracy = epoch_correct_symbols / total_symbols
        accuracy.append(epoch_accuracy)
        epoch_lr = lr_scheduler.get_lr()[0]
        learning_rates.append(epoch_lr)

        save_checkpoint(
            {
                "epoch": start_epoch + epoch + 1,
                "losses": losses,
                "accuracy": accuracy,
                "lr": learning_rates,
                "grad_norm": grad_norms,
                "model": {
                    "encoder": enc.state_dict(),
                    "decoder": dec.state_dict()
                },
                "optimiser": optimiser.state_dict(),
            },
            prefix=prefix,
        )

        elapsed_time = time.time() - start_time
        elapsed_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        if epoch % print_epochs == 0 or epoch == num_epochs - 1:
            print("{epoch_text}: "
                  "Accuracy = {accuracy:.5f}, "
                  "Loss = {loss:.5f}, "
                  "lr = {lr} "
                  "(time elapsed {time})".format(
                      epoch_text=epoch_text,
                      accuracy=epoch_accuracy,
                      loss=mean_epoch_loss,
                      lr=epoch_lr,
                      time=elapsed_time,
                  ))
            write_tensorboard(
                writer,
                start_epoch + epoch + 1,
                mean_epoch_loss,
                epoch_accuracy,
                mean_epoch_grad_norm,
                enc,
                dec,
            )

    return np.array(losses), np.array(accuracy)
Example #13
def main():
    # I must know that at least something is running
    print("Please wait while I train")
    
    args = parse_args()
    #path of data directories
    data_dir = 'flowers'
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/valid'
    test_dir = data_dir + '/test'
    
    #transformations to be applied on dataset
    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
    test_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], 
                                                                  [0.229, 0.224, 0.225])])
    val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], 
                                                                  [0.229, 0.224, 0.225])])
    
    # TODO: Load the datasets with ImageFolder
    train_datasets = datasets.ImageFolder(train_dir, transform=train_transforms)
    test_datasets = datasets.ImageFolder(test_dir, transform=test_transforms)
    val_datasets = datasets.ImageFolder(val_dir, transform=val_transforms)
    
    # TODO: Using the image datasets and the transforms, define the dataloaders
    trainloader = torch.utils.data.DataLoader(train_datasets, batch_size = 64, shuffle=True)
    valloader = torch.utils.data.DataLoader(val_datasets, batch_size = 64, shuffle=True)
    testloader = torch.utils.data.DataLoader(test_datasets, batch_size = 64, shuffle=True)
    
    #print(summary(trainloaders))
    #image, label = next(iter(trainloader))
    #helper.imshow(image[0,:]);
    
    #defining parameters that will be passed as default to the model under training
    
    model = getattr(models, args.arch)(pretrained=True)
    
    #choose out of two models
    if args.arch == 'vgg13':
        # TODO: Build and train your network
        model = models.vgg13(pretrained=True)
        print(model)
        for param in model.parameters():
            param.requires_grad = False
    
        classifier = nn.Sequential(nn.Linear(25088, 4096),
                               nn.Dropout(p=0.2),
                               nn.ReLU(),
                               nn.Linear(4096, 4096),
                               nn.ReLU(),
                               nn.Dropout(p=0.2),
                               nn.Linear(4096,102),
                               nn.LogSoftmax(dim=1))
        model.classifier= classifier
   
    elif args.arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        print(model)
        for param in model.parameters():
            param.requires_grad = False
    
        classifier = nn.Sequential(nn.Linear(1024, 512),
                               nn.Dropout(p=0.6),
                               nn.ReLU(),
                               nn.Linear(512, 256),
                               nn.ReLU(),
                               nn.Dropout(p=0.6),                               
                               nn.Linear(256,102),
                               nn.LogSoftmax(dim=1))
        model.classifier = classifier
    
    criterion = nn.NLLLoss()
    epochs = int(args.epochs)
    learning_rate = float(args.learning_rate)
    print_every = int(args.print_every)
    optimizer = optim.Adam(model.classifier.parameters(), lr=learning_rate)
    train(model, criterion, epochs, optimizer, print_every, trainloader, valloader)
    model.class_to_idx = train_datasets.class_to_idx        
    path = args.save_dir
    save_checkpoint(args, model, optimizer, learning_rate, epochs, path)
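A hypothetical save_checkpoint helper matching the call above; the checkpoint keys are illustrative and not necessarily the original project's exact format.

import os
import torch

def save_checkpoint(args, model, optimizer, learning_rate, epochs, path):
    # Bundle everything needed to rebuild the classifier for inference.
    checkpoint = {
        'arch': args.arch,
        'classifier': model.classifier,
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx,
        'optimizer_state': optimizer.state_dict(),
        'learning_rate': learning_rate,
        'epochs': epochs,
    }
    torch.save(checkpoint, os.path.join(path, 'checkpoint.pth'))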
Example #14
                wandb.log({
                    "LR": scheduler.get_last_lr()[0],
                    "Throughput": step_throughput,
                    "Loss": mean_loss,
                    "Accuracy": acc
                })

            if current_step + 1 == training_steps:
                break  # Training finished mid-epoch
            save_every = current_step % config.checkpoint_steps == 0
            not_finished = (current_step + 1 != training_steps)
            if config.checkpoint_output_dir and save_every and not_finished:
                model.deparallelize()
                save_checkpoint(config,
                                model,
                                optimizer,
                                current_step,
                                metrics={"Loss": mean_loss})
                model.parallelize()

    stop_train = time.perf_counter()

    if config.checkpoint_output_dir:
        # Checkpoint at end of run
        model.deparallelize()
        save_checkpoint(config, model, optimizer, training_steps)
    logger.info("---------------------------------------")

    logger.info("---------- Training Metrics -----------")
    logger.info(f"global_batch_size: {config.global_batch_size}")
    logger.info(f"batches_per_step: {config.batches_per_step}")
Example #15
def main(args):

    if len(args.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        cudnn.benchmark = True
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        args.device = torch.device("cuda:{}".format(args.gpu_ids[0]))
    else:
        kwargs = {}
        args.device = torch.device("cpu")

    normlizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    
    print("Building dataset: " + args.dataset)

    if args.dataset == "cifar10":
        args.num_class = 10
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.dataset_dir, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normlizer])),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.dataset_dir, train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            normlizer])),
            batch_size=100, shuffle=False, **kwargs)

    elif args.dataset == "cifar100":
        args.num_class = 100
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(args.dataset_dir, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normlizer])), 
                        batch_size=args.batch_size, shuffle=True, **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(args.dataset_dir, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normlizer])),
                batch_size=100, shuffle=False, **kwargs)

    net = create_net(args)

    print(net)

    optimizer = optim.SGD(net.parameters(), lr=args.base_lr, momentum=args.beta1, weight_decay=args.weight_decay)

    if args.resume:
        net, optimizer, best_acc, start_epoch = load_checkpoint(args, net, optimizer)
    else:
        start_epoch = 0
        best_acc = 0

    x = torch.randn(1, 3, 32, 32)
    flops, params = profile(net, inputs=(x,))

    print("Number of params: %.6fM" % (params / 1e6))
    print("Number of FLOPs: %.6fG" % (flops / 1e9))

    args.log_file.write("Network - " + args.arch + "\n")
    args.log_file.write("Attention Module - " + args.attention_type + "\n")
    args.log_file.write("Params - %.6fM" % (params / 1e6) + "\n")
    args.log_file.write("FLOPs - %.6fG" % (flops / 1e9) + "\n")
    args.log_file.write("--------------------------------------------------" + "\n")

    if len(args.gpu_ids) > 0:
        net.to(args.gpu_ids[0])
        net = torch.nn.DataParallel(net, args.gpu_ids)  # multi-GPUs

    for epoch in range(start_epoch, args.num_epoch):
        # if args.wrn:
            # adjust_learning_rate_wrn(optimizer, epoch, args.warmup)
        # else:
        adjust_learning_rate(optimizer, epoch, args.warmup)

        train(net, optimizer, epoch, train_loader, args)
        epoch_acc = validate(net, epoch, test_loader, args)

        is_best = epoch_acc > best_acc
        best_acc = max(epoch_acc, best_acc)

        save_checkpoint({
            "epoch": epoch + 1,
            "arch": args.arch,
            "state_dict": net.module.cpu().state_dict(),
            "best_acc": best_acc,
            "optimizer" : optimizer.state_dict(),
            }, is_best, epoch, save_path=args.ckpt)

        net.to(args.device)

        args.log_file.write("--------------------------------------------------" + "\n")

    args.log_file.write("best accuracy %4.2f" % best_acc)

    print("Job Done!")
Example #16
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    model = create_net(args)

    x = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(x, ))

    print("model [%s] - params: %.6fM" % (args.arch, params / 1e6))
    print("model [%s] - FLOPs: %.6fG" % (args.arch, flops / 1e9))

    log_file = os.path.join(args.ckpt, "log.txt")

    if os.path.exists(log_file):
        args.log_file = open(log_file, mode="a")
    else:
        args.log_file = open(log_file, mode="w")
        args.log_file.write("Network - " + args.arch + "\n")
        args.log_file.write("Attention Module - " + args.attention_type + "\n")
        args.log_file.write("Params - " % str(params) + "\n")
        args.log_file.write("FLOPs - " % str(flops) + "\n")
        args.log_file.write(
            "--------------------------------------------------" + "\n")

    args.log_file.close()

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.device)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.device)
        model = model.to(args.gpu[0])
        model = torch.nn.DataParallel(model, args.gpu)

    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.resume:
        model, optimizer, best_acc1, start_epoch = load_checkpoint(
            args, model, optimizer)
        args.start_epoch = start_epoch

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.save_weights is not None:  # "deparallelize" saved weights
        print("=> saving 'deparallelized' weights [%s]" % args.save_weights)
        model = model.module
        model = model.cpu()
        torch.save({'state_dict': model.state_dict()},
                   args.save_weights,
                   _use_new_zipfile_serialization=False)
        return

    if args.evaluate:
        args.log_file = open(log_file, mode="a")
        validate(val_loader, model, criterion, args)
        args.log_file.close()
        return

    if args.cos_lr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs)
        for epoch in range(args.start_epoch):
            scheduler.step()

    for epoch in range(args.start_epoch, args.epochs):

        args.log_file = open(log_file, mode="a")

        if args.distributed:
            train_sampler.set_epoch(epoch)

        if (not args.cos_lr):
            adjust_learning_rate(optimizer, epoch, args)
        else:
            scheduler.step()
            print('[%03d] %.5f' % (epoch, scheduler.get_lr()[0]))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        args.log_file.close()

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                epoch,
                save_path=args.ckpt)
Example #17
        next_state, reward, done, info = env.step(action)
        agent_reward = clip_reward(reward)
        agent.add_memory(cur_state, action, agent_reward, next_state, done)

        score += reward
        agent_score += agent_reward

        loss = agent.optimize_model()
        if loss is not None:
            model_updates += 1
            delta = loss - mean_loss
            mean_loss += delta / model_updates

        cur_state = next_state

        time_step += 1
        total_steps += 1
        # if time_step % 100 == 0:
        #   print("Completed iteration", time_step)

    print(
        "Episode {} score: {}, agent score: {}, total steps taken: {}, epsilon: {}"
        .format(i_episode, score, agent_score, total_steps, epsilon))
    progress.append(
        (time_step, total_steps, score, agent_score, mean_loss, value))
    # print("Progress is", progress)
    if CKPT_ENABLED and score > max_score:
        max_score = score
        save_checkpoint(progress, dqn_online, dqn_target, optimizer,
                        CKPT_FILENAME)
Example #18
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop and all reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))  # match image_valid's batch size
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Load checkpoint if the file exists.
    start_point = 0
    if args.use_latest_checkpoint:
        files = glob.glob(f'{args.model_save_path}/checkpoint_*.json')
        if len(files) != 0:
            index = max([
                int(n) for n in
                [re.sub(r'.*checkpoint_(\d+).json', '\\1', f) for f in files]
            ])
            # Load weights and solver state from the specified checkpoint file.
            start_point = load_checkpoint(
                f'{args.model_save_path}/checkpoint_{index}.json', solver)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator

    # If the data does not exist, it is downloaded from the server and prepared.
    # When running multiple processes on the same host, the initial data
    # preparation must be done by a single representative process on that host.

    # Prepare data only on the rank-0 process
    if mpi_rank == 0:
        rng = np.random.RandomState(device_id)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(args.batch_size, False)

    # Wait for data to be prepared without watchdog
    comm.barrier()

    # Prepare data on the remaining (non-zero rank) processes
    if mpi_rank != 0:
        rng = np.random.RandomState(device_id)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(args.batch_size, False)

    # loss_error_train.forward()

    # Training-loop
    ve = nn.Variable()
    model_save_interval = 0
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:  # note that a smaller final batch is skipped
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if mpi_rank == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if model_save_interval <= 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))
                    save_checkpoint(args.model_save_path, i, solver)
                    model_save_interval += int(args.model_save_interval /
                                               n_devices)
        model_save_interval -= 1

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if mpi_rank == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

        # exit(0)

    if mpi_rank == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
    comm.barrier()
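
The `Linear Warmup` branch above ramps the learning rate from `base_lr` at iteration 0 up to roughly `base_lr * n_devices` at `warmup_iter`, a common heuristic for the larger effective batch size of data-parallel training; afterwards the rate simply stays at the last value set. A standalone sketch of that schedule (all names below are illustrative):

def warmup_lr(i, base_lr, n_devices, warmup_iter):
    # Linear ramp: base_lr at i=0, base_lr * n_devices at i=warmup_iter,
    # then constant (mirrors the `if i <= warmup_iter` branch above).
    if i > warmup_iter:
        return base_lr * n_devices
    slope = base_lr * (n_devices - 1) / warmup_iter
    return base_lr + slope * i

# e.g. base_lr=0.001 on 4 devices with 390 warmup iterations
print([round(warmup_lr(i, 1e-3, 4, 390), 5) for i in (0, 195, 390, 500)])
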
Exemple #19
0
def train(args):
    """
    Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    Steps:
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Load checkpoint to resume previous training.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Zero the parameter gradients.
      * Execute backpropagation.
      * AllReduce the gradients.
      * The solver updates parameters using the gradients computed by backprop and all-reduce.
      * Compute the training error.
    """
    # Create Communicator and Context
    comm = create_communicator(ignore_error=True)
    if comm:
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = comm.local_rank
    else:
        n_devices = 1
        mpi_rank = 0
        device_id = args.device_id

    if args.context == 'cpu':
        import nnabla_ext.cpu
        context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cudnn
        context = nnabla_ext.cudnn.context(device_id=device_id)
    nn.set_default_context(context)

    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size
    iter_per_epoch = int(n_train_samples / args.batch_size / n_devices)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)

    # Create validation graphs
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = iter_per_epoch * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Load checkpoint if the file exists.
    start_point = 0
    if args.use_latest_checkpoint:
        files = glob.glob(f'{args.model_save_path}/checkpoint_*.json')
        if len(files) != 0:
            index = max([
                int(n) for n in
                [re.sub(r'.*checkpoint_(\d+).json', '\\1', f) for f in files]
            ])
            # Load weights and solver state from the specified checkpoint file.
            start_point = load_checkpoint(
                f'{args.model_save_path}/checkpoint_{index}.json', solver)
            print(f'Checkpoint loaded; resuming from iteration {start_point}')

    # Create monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator

    # If the data does not exist, it is downloaded from the server and prepared.
    # When running multiple processes on the same host, the initial data
    # preparation must be done by the representative process (rank 0) on the host.

    # Download dataset by rank-0 process
    if single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Wait for data to be prepared without watchdog
    if comm:
        comm.barrier()

    # Prepare the dataset on the remaining processes
    if not single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(start_point // n_devices, args.epochs * iter_per_epoch):
        # Validation
        if i % iter_per_epoch == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:  # note that a smaller final batch is skipped
                    continue
                image_valid.d = image
                label_valid.d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            if comm:
                comm.all_reduce(ve.data, division=True, inplace=True)

            # Monitoring error and elapsed time
            if single_or_rankzero():
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Save model
        if single_or_rankzero():
            if i % (args.model_save_interval // n_devices) == 0:
                iter = i * n_devices
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'params_%06d.h5' % iter))
                if args.use_latest_checkpoint:
                    save_checkpoint(args.model_save_path, iter, solver)

        # Forward/Zerograd
        image, label = tdata.next()
        image_train.d = image
        label_train.d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Monitoring loss, error and elapsed time
        if single_or_rankzero():
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Save nnp last epoch
    if single_or_rankzero():
        runtime_contents = {
            'networks': [{
                'name': 'Validation',
                'batch_size': args.batch_size,
                'outputs': {
                    'y': pred_valid
                },
                'names': {
                    'x': image_valid
                }
            }],
            'executors': [{
                'name': 'Runtime',
                'network': 'Validation',
                'data': ['x'],
                'output': ['y']
            }]
        }
        iter = args.epochs * iter_per_epoch
        nn.save_parameters(
            os.path.join(args.model_save_path, 'params_%06d.h5' % iter))
        nnabla.utils.save.save(
            os.path.join(args.model_save_path, f'{args.net}_result.nnp'),
            runtime_contents)
    if comm:
        comm.barrier()
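
Both nnabla examples resume training by selecting the checkpoint file with the largest numeric suffix. A minimal sketch of that selection logic, assuming files named `checkpoint_<iteration>.json` (the helper name and directory layout are assumptions for illustration):

import glob
import re

def latest_checkpoint(model_save_path):
    """Return the checkpoint_<N>.json path with the largest N, or None if absent."""
    files = glob.glob(f"{model_save_path}/checkpoint_*.json")
    if not files:
        return None
    index = max(int(re.sub(r".*checkpoint_(\d+)\.json", r"\1", f)) for f in files)
    return f"{model_save_path}/checkpoint_{index}.json"
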
Exemple #20
0
def train(
    enc,
    dec,
    optimiser,
    criterion,
    train_data_loader,
    validation_data_loader,
    device,
    teacher_forcing_ratio=teacher_forcing_ratio,
    lr_scheduler=None,
    num_epochs=100,
    print_epochs=None,
    checkpoint=default_checkpoint,
    prefix="",
    max_grad_norm=max_grad_norm,
):
    if print_epochs is None:
        print_epochs = num_epochs

    writer = init_tensorboard(name=prefix.strip("-"))
    start_epoch = checkpoint["epoch"]
    train_accuracy = checkpoint["train_accuracy"]
    train_losses = checkpoint["train_losses"]
    validation_accuracy = checkpoint["validation_accuracy"]
    validation_losses = checkpoint["validation_losses"]
    learning_rates = checkpoint["lr"]
    grad_norms = checkpoint["grad_norm"]

    for epoch in range(num_epochs):
        start_time = time.time()

        if lr_scheduler:
            lr_scheduler.step()

        epoch_text = "[{current:>{pad}}/{end}] Epoch {epoch}".format(
            current=epoch + 1,
            end=num_epochs,
            epoch=start_epoch + epoch + 1,
            pad=len(str(num_epochs)),
        )

        train_result = run_epoch(
            train_data_loader,
            enc,
            dec,
            epoch_text,
            criterion,
            optimiser,
            teacher_forcing_ratio,
            max_grad_norm,
            device,
            train=True,
        )
        train_losses.append(train_result["loss"])
        grad_norms.append(train_result["grad_norm"])
        train_epoch_accuracy = (train_result["correct_symbols"] /
                                train_result["total_symbols"])
        train_accuracy.append(train_epoch_accuracy)
        # lr_scheduler may be None (its default), so fall back to the optimiser's lr.
        epoch_lr = (lr_scheduler.get_lr()[0]
                    if lr_scheduler else optimiser.param_groups[0]["lr"])
        learning_rates.append(epoch_lr)

        validation_result = run_epoch(
            validation_data_loader,
            enc,
            dec,
            epoch_text,
            criterion,
            optimiser,
            teacher_forcing_ratio,
            max_grad_norm,
            device,
            train=False,
        )
        validation_losses.append(validation_result["loss"])
        validation_epoch_accuracy = (validation_result["correct_symbols"] /
                                     validation_result["total_symbols"])
        validation_accuracy.append(validation_epoch_accuracy)

        save_checkpoint(
            {
                "epoch": start_epoch + epoch + 1,
                "train_losses": train_losses,
                "train_accuracy": train_accuracy,
                "validation_losses": validation_losses,
                "validation_accuracy": validation_accuracy,
                "lr": learning_rates,
                "grad_norm": grad_norms,
                "model": {
                    "encoder": enc.state_dict(),
                    "decoder": dec.state_dict()
                },
                "optimiser": optimiser.state_dict(),
            },
            prefix=prefix,
        )

        elapsed_time = time.time() - start_time
        elapsed_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        if epoch % print_epochs == 0 or epoch == num_epochs - 1:
            print(("{epoch_text}: "
                   "Train Accuracy = {train_accuracy:.5f}, "
                   "Train Loss = {train_loss:.5f}, "
                   "Validation Accuracy = {validation_accuracy:.5f}, "
                   "Validation Loss = {validation_loss:.5f}, "
                   "lr = {lr} "
                   "(time elapsed {time})").format(
                       epoch_text=epoch_text,
                       train_accuracy=train_epoch_accuracy,
                       train_loss=train_result["loss"],
                       validation_accuracy=validation_epoch_accuracy,
                       validation_loss=validation_result["loss"],
                       lr=epoch_lr,
                       time=elapsed_time,
                   ))
            write_tensorboard(
                writer,
                start_epoch + epoch + 1,
                train_result["grad_norm"],
                train_result["loss"],
                train_epoch_accuracy,
                validation_result["loss"],
                validation_epoch_accuracy,
                enc,
                dec,
            )
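
`max_grad_norm` is passed into `run_epoch` above, but the clipping itself is not shown; in PyTorch it is usually applied between `backward()` and the optimiser step, and the pre-clipping norm is what gets logged as `grad_norm`. A hedged sketch of that step (the helper and its arguments are placeholders, not the snippet's actual `run_epoch`):

import torch

def clipped_step(model, optimiser, loss, max_grad_norm):
    optimiser.zero_grad()
    loss.backward()
    # Rescale gradients so their global L2 norm is at most max_grad_norm and
    # return the pre-clipping norm for logging.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimiser.step()
    return float(grad_norm)
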
Exemple #21
0
def trian_validate_with_scheduling(args,
                                   net,
                                   optimizer,
                                   compress_scheduler,
                                   device,
                                   dataloaders,
                                   dataset_sizes,
                                   loggers,
                                   tracker,
                                   epoch=1,
                                   validate=True,
                                   verbose=True):
    # What's collectors_context?
    # First, determine the checkpoint name from the model and its learning stage:
    #if not os.path.isdir(args.output_dir):
    #    os.mkdir(args.output_dir)
    #os.mkdir(args.output_dir, exist_ok=True)

    name = args.name
    if args.name == '':
        name = args.arch + "_" + args.dataset
    # A pruning mode must exist.
    # Reset the learning rate and the optimizer's momentum buffer for the next learning stage.
    # Be aware of whether the learning-rate decay is based on epochs or on steps,
    # and note that the last_epoch argument indicates the current epoch.
    #****
    # NOTE: this may fail if no policy is registered for this epoch.
    #****
    if compress_scheduler:
        if compress_scheduler.prune_mechanism:
            if epoch == (compress_scheduler.pruner_info['max_epoch']):
                # Reset optimizer and learning rate in retrain phase.
                # NOTE: We should specify the true
                for index in range(len(compress_scheduler.policies[epoch])):
                    policy_name = compress_scheduler.policies[epoch][
                        index].__class__.__name__.split("Policy")[0]
                    if policy_name == "LR":
                        compress_scheduler.policies[epoch][
                            index].lr_scheduler.optimizer.param_groups[0][
                                'lr'] = args.lr_retrain
                        compress_scheduler.policies[epoch][
                            index].lr_scheduler.base_lrs = [args.lr_retrain]
                        compress_scheduler.policies[epoch][
                            index].lr_scheduler.optimizer.param_groups[0][
                                'momentum'] = 0.9
                        compress_scheduler.policies[epoch][
                            index].lr_scheduler.optimizer.param_groups[0][
                                'initial_lr'] = args.lr_retrain
                        for group in optimizer.param_groups:
                            for p in group['params']:
                                if 'momentum_buffer' in optimizer.state[p]:
                                    optimizer.state[p].pop(
                                        'momentum_buffer', None)
                        break

            if epoch == (compress_scheduler.pruner_info['min_epoch']):
                # *****
                # NOTE If not using ADMM pruner, do we need to reset lr scheduler in this loop?
                # *****
                # Reset learning rate and momentum buffer for pruning stage!
                policy_name = compress_scheduler.policies[epoch][
                    0].__class__.__name__.split("Policy")[0]
                #if policy_name != "ADMM":
                #    compress_scheduler.policies[epoch][0].lr_scheduler.optimizer.param_groups[0]['lr'] = args.lr_prune
                #    compress_scheduler.policies[epoch][0].lr_scheduler.base_lrs = [args.lr_prune]
                #    compress_scheduler.policies[epoch][0].lr_scheduler.optimizer.param_groups[0]['momentum'] = 0.9
                #    compress_scheduler.policies[epoch][0].lr_scheduler.optimizer.param_groups[0]['initial_lr'] = args.lr_prune
                for group in optimizer.param_groups:
                    group['lr'] = args.lr_prune
                    group['initial_lr'] = args.lr_prune
                    # for group in optimizer.param_groups:
                    for p in group['params']:
                        if 'momentum_buffer' in optimizer.state[p]:
                            optimizer.state[p].pop('momentum_buffer', None)

            if epoch >= compress_scheduler.pruner_info['max_epoch']:
                name += "_retrain"

            elif epoch < compress_scheduler.pruner_info['min_epoch']:
                name += "_pretrain"

            else:
                name += "_prune"
    else:
        # No compression scheduler: only a pre-train or re-train stage model is handled.
        name = name + "_" + args.stage

    if compress_scheduler:
        #dataset_name = 'val' if args.split_ratio != 0 else 'test'
        dataset_name = 'test'
        #data_group, model, criterion, device, num_classes, loggers, epoch=-1, noise_factor=0)

        # These two partial functions are created for the channel-pruning methods (Prof. Han's):
        forward_fn = partial(_validate,
                             args=args,
                             dataloaders=dataloaders,
                             data_group=dataset_name,
                             device=device,
                             loggers=loggers,
                             epoch=-1,
                             noise_factor=0)

        #light_train_with_distiller(model, criterion, optimizer, compress_scheduler, device, num_classes,
        #                       dataset_sizes, loggers, epoch=1)

        train_fn = partial(light_train_with_distiller,
                           args=args,
                           optimizer=optimizer,
                           compress_scheduler=compress_scheduler,
                           device=device,
                           dataloaders=dataloaders,
                           dataset_sizes=dataset_sizes,
                           loggers=loggers,
                           epoch=0)

        compress_scheduler.on_epoch_begin(epoch,
                                          optimizer,
                                          forward_fn=forward_fn,
                                          train_fn=train_fn)

    nat_loss, overall_loss = light_train_with_distiller(
        args, net, optimizer, compress_scheduler, device, dataloaders,
        dataset_sizes, loggers, epoch)

    if validate:
        nat_loss = _validate(args,
                             dataloaders,
                             'test',
                             net,
                             device,
                             loggers,
                             epoch=epoch)

    if compress_scheduler:
        # Or we can compute IoU here?
        loss = nat_loss
        #top1 = nat_loss
        compress_scheduler.on_epoch_end(epoch,
                                        optimizer,
                                        metrics={'min': loss})  #, 'max':top1})

    is_best, checkpoint_extras = _finalize_epoch(args,
                                                 net,
                                                 tracker,
                                                 epoch,
                                                 top1=nat_loss)
    # Check whether the output directory already exists.
    ckpt.save_checkpoint(epoch,
                         args.arch,
                         net,
                         optimizer=optimizer,
                         scheduler=compress_scheduler,
                         extras=checkpoint_extras,
                         is_best=is_best,
                         name=name,
                         dir=msglogger.logdir)

    #return top1, top5, loss, tracker
    return nat_loss, tracker
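
The phase transitions above reset the learning rate and clear stale momentum buffers so the pruning or retraining stage starts from a clean optimizer state. A minimal sketch of that reset for a torch.optim.SGD-style optimizer (the function name is illustrative):

def reset_sgd_state(optimizer, new_lr, momentum=0.9):
    """Set a fresh learning rate and drop stale momentum buffers."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr
        group["initial_lr"] = new_lr
        group["momentum"] = momentum
        for p in group["params"]:
            # Popping the buffer makes the next step behave like the very first one.
            optimizer.state[p].pop("momentum_buffer", None)
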
Exemple #22
0
def train_network(trainloader, validloader, args, train_data=None):
    # Build the classifier with the specified arguments
    model, optimizer, criterion, device = build_classifier(
        args.arch, args.hidden_units, args.learning_rate, args.gpu)
    running_loss, steps = 0, 0
    print_every = 40
    model.train()
    print('Training the network')
    for epoch in range(args.epochs):
        for inputs, labels in trainloader:
            steps += 1

            #move to the right device
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()  #zero the optimizer gradient value

            log_pred = model(inputs)  #calculate model prediction

            loss = criterion(log_pred, labels)  #calculate model loss

            loss.backward()  # backpropagate the loss

            optimizer.step()  #update the weights and bias

            running_loss += loss.item()  # sum the loss
            if steps % print_every == 0:
                valid_loss = 0
                accuracy = 0
                # Switch to evaluation mode and disable gradient tracking to
                # save memory and computation during validation.
                model.eval()
                with torch.no_grad():
                    for inputs, labels in validloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        v_pred = model.forward(inputs)
                        batch_loss = criterion(v_pred, labels)
                        valid_loss += batch_loss.item()

                        #calculate model accuracy
                        ps = torch.exp(v_pred)
                        top_probs, top_labels = ps.topk(1, dim=1)
                        equals = top_labels == labels.view(*top_labels.shape)
                        accuracy += torch.mean(equals.type(
                            torch.FloatTensor)).item()

                print(
                    "Epoch: {}/{}.. ".format(epoch + 1, args.epochs),
                    "Training Loss: {:.3f}.. ".format(running_loss /
                                                      print_every),
                    "Validation Loss: {:.3f}.. ".format(valid_loss /
                                                        len(validloader)),
                    "Validation Accuracy: {:.3f}%".format(accuracy /
                                                          len(validloader) * 100))

                running_loss = 0

                model.train()

    print("Finished the training and test!")
    if train_data != None:
        checkpoint.save_checkpoint(model, optimizer, args, train_data)
        print("Checkpoint saved!")
Exemple #23
0
def train(
    model=None,
    dataloader_primary=None,
    dataloader_aux=None,
    dataloader_val=None,
    optimizer=None,
    completed_epochs=None,
    loss=None,
    args=None,
):
    """Train a model with a ResNet18 feature extractor on data from the primary and auxiliary domain(s), adapted from: https://github.com/fungtion/DANN_py3/blob/master/main.py"""
    wandb.init(config=args, project=args.project, name=args.run_name)
    best_acc = 0.0
    best_epoch_loss_label_primary = sys.float_info.max if loss is None else loss
    val_acc_history = []

    for epoch in range(completed_epochs + 1,
                       completed_epochs + args.epochs + 1):
        len_dataloader = len(dataloader_primary)
        if dataloader_aux is not None and args.mode == "dann":
            len_dataloader = min(len(dataloader_primary), len(dataloader_aux))
            data_aux_iter = iter(dataloader_aux)
        data_primary_iter = iter(dataloader_primary)
        loss_label_classifier = torch.nn.CrossEntropyLoss()
        loss_domain_classifier = torch.nn.CrossEntropyLoss()

        running_loss_label_primary = 0

        model.train()
        for i in range(1, len_dataloader + 1):
            # Training with primary data
            data_primary = next(data_primary_iter)  # .next() was removed from DataLoader iterators
            primary_img, primary_label = data_primary
            primary_img, primary_label = (
                primary_img.to(device),
                primary_label.to(device),
            )
            model.zero_grad()
            primary_domain = torch.zeros_like(primary_label)
            primary_domain = primary_domain.to(device)
            class_output, domain_output = model(primary_img)
            preds = class_output.argmax(dim=1)
            batch_acc_train = preds.eq(
                primary_label).sum().item() / primary_label.size(0)
            domain_preds = domain_output.argmax(dim=1)
            batch_acc_domain_primary = domain_preds.eq(
                primary_domain).sum().item() / primary_domain.size(0)

            loss_primary_label = loss_label_classifier(class_output,
                                                       primary_label)
            running_loss_label_primary += loss_primary_label.data.cpu().item()
            loss_primary_domain = loss_domain_classifier(
                domain_output, primary_domain)

            # Training with auxiliary data
            loss_aux_domain = torch.FloatTensor(1).fill_(0)
            batch_acc_domain_aux = 0
            if dataloader_aux is not None and args.mode == "dann":
                data_aux = next(data_aux_iter)
                aux_img, aux_label = data_aux
                aux_img, aux_label = aux_img.to(device), aux_label.to(device)
                aux_domain = torch.ones_like(aux_label)
                aux_domain = aux_domain.to(device)
                _, domain_output = model(aux_img)
                domain_preds = domain_output.argmax(dim=1)
                batch_acc_domain_aux = domain_preds.eq(
                    aux_domain).sum().item() / aux_domain.size(0)
                loss_aux_domain = loss_domain_classifier(
                    domain_output, aux_domain)

            if args.mode == "dann":
                loss = loss_primary_label + loss_aux_domain + loss_primary_domain
            else:
                loss = loss_primary_label
            loss.backward()
            optimizer.step()

            if i % args.log_interval == 0:
                print(
                    "epoch: %d, [iter: %d / all %d], loss_primary_label: %f, loss_primary_domain: %f, loss_aux_domain: %f, acc_primary_label_batch: %f, acc_primary_domain_batch: %f, acc_aux_domain_batch: %f"
                    % (
                        epoch,
                        i,
                        len_dataloader,
                        loss_primary_label.data.cpu().item(),
                        loss_primary_domain.data.cpu().item(),
                        loss_aux_domain.data.cpu().item(),
                        batch_acc_train,
                        batch_acc_domain_primary,
                        batch_acc_domain_aux,
                    ))
            wandb.log({
                "loss_primary_label":
                loss_primary_label.data.cpu().item(),
                "loss_primary_domain":
                loss_primary_domain.data.cpu().item(),
                "loss_aux_domain":
                loss_aux_domain.data.cpu().item(),
                "acc_primary_label_batch":
                batch_acc_train,
                "acc_primary_domain_batch":
                batch_acc_domain_primary,
                "acc_aux_domain_batch":
                batch_acc_domain_aux,
            })

        epoch_loss_label_primary = running_loss_label_primary / len_dataloader
        print("epoch: %d, loss_primary_label: %f" %
              (epoch, epoch_loss_label_primary))

        if dataloader_val is not None:
            val_loss, val_acc = test(model, dataloader_val)
            print("val_loss: %f, val_acc: %f" % (val_loss, val_acc))
            wandb.log({
                "val_loss_label": val_loss,
                "val_acc_label": val_acc,
            })
            val_acc_history.append(val_acc)
            if val_acc > best_acc:
                best_acc = val_acc

        if epoch_loss_label_primary < best_epoch_loss_label_primary:
            best_epoch_loss_label_primary = epoch_loss_label_primary
            save_checkpoint(
                checkpoint_dir=args.checkpoint_dir,
                run_name=args.run_name,
                checkpoint_name="best.pt",
                model=model,
                epoch=epoch,
                loss=best_epoch_loss_label_primary,
                optimizer=optimizer,
                args=args,
            )
        save_checkpoint(
            checkpoint_dir=args.checkpoint_dir,
            run_name=args.run_name,
            checkpoint_name="latest.pt",
            model=model,
            optimizer=optimizer,
            epoch=epoch,
            loss=best_epoch_loss_label_primary,
            args=args,
        )

    return model, val_acc_history
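
In `dann` mode the objective above sums the supervised label loss on the primary domain with domain-classification losses on both the primary and the auxiliary batch; the gradient-reversal branch inside the model is what turns the domain loss into a push toward domain-invariant features. A compact sketch of just the loss composition (all tensors below are simulated):

import torch

ce = torch.nn.CrossEntropyLoss()

# Simulated outputs: 4 primary and 4 auxiliary samples, 10 classes, 2 domains.
class_out_primary = torch.randn(4, 10)
dom_out_primary = torch.randn(4, 2)
dom_out_aux = torch.randn(4, 2)
labels_primary = torch.randint(0, 10, (4,))
domain_primary = torch.zeros(4, dtype=torch.long)  # primary domain label = 0
domain_aux = torch.ones(4, dtype=torch.long)       # auxiliary domain label = 1

loss = (ce(class_out_primary, labels_primary)  # supervised label loss
        + ce(dom_out_primary, domain_primary)  # domain loss on the primary batch
        + ce(dom_out_aux, domain_aux))         # domain loss on the auxiliary batch
print(loss.item())
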
Exemple #24
0
def train(train_dataloader_X, train_dataloader_Y, 
        test_dataloader_X, test_dataloader_Y, 
        device, n_epochs, balance,
        reconstruction_weight, identity_weight,
        print_every=1, checkpoint_every=10, sample_every=10):
    
    
    # keep track of losses over time
    losses = []

    # Get some fixed data from domains X and Y for sampling. These images are held
    # constant throughout training and let us inspect the model's progress.
    fixed_X = next(iter(test_dataloader_X))[0] #test_iter_X.next()[0]
    fixed_Y = next(iter(test_dataloader_Y))[0] #test_iter_Y.next()[0]
    
    # scale to a range -1 to 1
    fixed_X = scale(fixed_X.to(device))
    fixed_Y = scale(fixed_Y.to(device))

    # batches per epoch
    iter_X = iter(train_dataloader_X)
    iter_Y = iter(train_dataloader_Y)
    batches_per_epoch = min(len(iter_X), len(iter_Y))

    for epoch in range(1, n_epochs+1):

        epoch_loss_d_x = 0
        epoch_loss_d_y = 0
        epoch_loss_g = 0

        for _ in range(batches_per_epoch):

            # move images to GPU or CPU depending on what is passed in the device parameter,
            # make sure to scale to a range -1 to 1
            images_X, _ = next(iter_X)
            images_X = scale(images_X.to(device))
            images_Y, _ = next(iter_Y)
            images_Y = scale(images_Y.to(device))


            # ============================================
            #            TRAIN THE DISCRIMINATORS
            # ============================================
            
            ##   First: D_X, real and fake loss components   ##

            # Compute the discriminator losses on real images
            d_x_out = D_X(images_X)
            d_x_loss_real = real_mse_loss(d_x_out)
            
            # Generate fake images that look like domain X based on real images in domain Y
            fake_x = G_YtoX(images_Y)

            # Compute the fake loss for D_X
            d_x_out = D_X(fake_x)
            d_x_loss_fake = fake_mse_loss(d_x_out)
            
            # Compute the total loss
            d_x_loss = d_x_loss_real + d_x_loss_fake
            

            
            ##   Second: D_Y, real and fake loss components   ##
            
            d_y_out = D_Y(images_Y) 
            d_y_real_loss = real_mse_loss(d_y_out)  # D_Y discriminator loss on a real Y image
            
            fake_y = G_XtoY(images_X) # generate fake Y image from the real X image
            d_y_out = D_Y(fake_y)
            d_y_fake_loss = fake_mse_loss(d_y_out) # compute D_y loss on a fake Y image
            
            d_y_loss = d_y_real_loss + d_y_fake_loss
            

            d_total_loss = d_x_loss + d_y_loss


            # =========================================
            #            TRAIN THE GENERATORS
            # =========================================

            ##    First: generate fake X images and reconstructed Y images    ##

            # Generate fake images that look like domain X based on real images in domain Y
            fake_x = G_YtoX(images_Y)

            # Compute the generator loss based on domain X
            d_out = D_X(fake_x)
            g_x_loss = real_mse_loss(d_out) # fake X should trick the D_x
            # TODO: consider using MSELoss or SmoothL1Loss (Huber loss)

            # Create a reconstructed y
            y_hat = G_XtoY(fake_x)
                    
            # Compute the cycle consistency loss (the reconstruction loss)
            rec_y_loss = cycle_consistency_loss(images_Y, y_hat, lambda_weight=reconstruction_weight)

            # Mapping a real X image through G_YtoX should act as an identity
            it_x = G_YtoX(images_X)

            # Compute the identity mapping loss
            it_x_loss = identity_mapping_loss(images_X, it_x, weight=identity_weight)


            ##    Second: generate fake Y images and reconstructed X images    ##
            fake_y = G_XtoY(images_X)
            
            d_out = D_Y(fake_y)
            g_y_loss = real_mse_loss(d_out)  # fake Y should trick the D_y
            
            x_hat = G_YtoX(fake_y)
            
            rec_x_loss = cycle_consistency_loss(images_X, x_hat, lambda_weight=reconstruction_weight)

            it_y = G_XtoY(images_Y)

            it_y_loss = identity_mapping_loss(images_Y, it_y, weight=identity_weight)


            # Add up all generator and reconstructed losses 
            g_total_loss = g_x_loss + g_y_loss + rec_x_loss + rec_y_loss + it_x_loss + it_y_loss
            

            # Perform backprop
            
            if d_total_loss >= balance*g_total_loss:
                d_x_optimizer.zero_grad()
                d_x_loss.backward()
                d_x_optimizer.step()

                d_y_optimizer.zero_grad()
                d_y_loss.backward()
                d_y_optimizer.step()
            
            if g_total_loss >= balance*d_total_loss:
                g_optimizer.zero_grad()
                g_total_loss.backward()
                g_optimizer.step()

            # Gather statistics
            epoch_loss_d_x += d_x_loss.item()
            epoch_loss_d_y += d_y_loss.item()
            epoch_loss_g += g_total_loss.item()


        # Reset the iterators when epoch ends
        iter_X = iter(train_dataloader_X)
        iter_Y = iter(train_dataloader_Y)

        # Print the log info
        if epoch % print_every == 0 or epoch == n_epochs:
            # append real and fake discriminator losses and the generator loss
            losses.append((epoch_loss_d_x, epoch_loss_d_y, epoch_loss_g))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                    epoch, n_epochs, epoch_loss_d_x, epoch_loss_d_y, epoch_loss_g))

            
        # Save the generated samples
        if epoch % sample_every == 0 or epoch == n_epochs:
            G_YtoX.eval() # set generators to eval mode for sample generation
            G_XtoY.eval()
            save_samples(epoch, fixed_Y, fixed_X, G_YtoX, G_XtoY, sample_dir='../samples')
            G_YtoX.train()
            G_XtoY.train()

        
        # Save the model parameters
        if epoch % checkpoint_every == 0 or epoch == n_epochs:
            save_checkpoint(G_XtoY, G_YtoX, D_X, D_Y, '../checkpoints')
            export_script_module(G_XtoY, '../artifacts', 'summer_to_winter_{:05d}.sm'.format(epoch))   
            export_script_module(G_YtoX, '../artifacts', 'winter_to_summer_{:05d}.sm'.format(epoch))             

    return losses
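
`cycle_consistency_loss` above penalises the difference between a real image and its round-trip reconstruction, scaled by `lambda_weight`; in CycleGAN this is typically an L1 distance. A hedged sketch under that assumption (the real helper in the snippet may differ):

import torch

def cycle_consistency_loss(real, reconstructed, lambda_weight=10.0):
    # Mean absolute (L1) difference between the original image and its
    # X -> Y -> X (or Y -> X -> Y) reconstruction, weighted as in CycleGAN.
    return lambda_weight * torch.mean(torch.abs(real - reconstructed))
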
Exemple #25
0
def train(
    enc,
    dec,
    optimiser,
    criterion,
    data_loader,
    device,
    lr_scheduler=None,
    num_epochs=100,
    print_epochs=None,
    checkpoint=default_checkpoint,
    prefix="",
):
    if print_epochs is None:
        print_epochs = num_epochs

    writer = init_tensorboard()
    total_symbols = len(data_loader.dataset) * data_loader.dataset.max_len
    start_epoch = checkpoint["epoch"]
    accuracy = checkpoint["accuracy"]
    losses = checkpoint["losses"]
    learning_rates = checkpoint["lr"]

    for epoch in range(num_epochs):
        start_time = time.time()
        epoch_losses = []
        epoch_correct_symbols = 0

        if lr_scheduler:
            lr_scheduler.step()

        for d in data_loader:
            input = d["image"].to(device)
            # The last batch may not be a full batch
            curr_batch_size = len(input)
            expected = torch.stack(d["truth"]["encoded"], dim=1).to(device)
            enc_low_res, enc_high_res = enc(input)
            # Decoder needs to be reset, because the coverage attention (alpha)
            # only applies to the current image.
            dec.reset(curr_batch_size)
            hidden = dec.init_hidden(curr_batch_size).to(device)
            # Starts with a START token
            sequence = torch.full(
                (curr_batch_size, 1),
                data_loader.dataset.token_to_id[START],
                dtype=torch.long,
                device=device,
            )
            decoded_values = []
            for i in range(data_loader.dataset.max_len - 1):
                previous = sequence[:, -1].view(-1, 1)
                out, hidden = dec(previous, hidden, enc_low_res, enc_high_res)
                _, top1_id = torch.topk(out, 1)
                sequence = torch.cat((sequence, top1_id), dim=1)
                decoded_values.append(out)

            decoded_values = torch.stack(decoded_values, dim=2).to(device)
            optimiser.zero_grad()
            # decoded_values does not contain the start symbol
            loss = criterion(decoded_values, expected[:, 1:])
            loss.backward()
            optimiser.step()

            epoch_losses.append(loss.item())
            epoch_correct_symbols += torch.sum(sequence == expected,
                                               dim=(0, 1)).item()

        mean_epoch_loss = np.mean(epoch_losses)
        losses.append(mean_epoch_loss)
        epoch_accuracy = epoch_correct_symbols / total_symbols
        accuracy.append(epoch_accuracy)
        # lr_scheduler may be None (its default), so fall back to the optimiser's lr.
        epoch_lr = (lr_scheduler.get_lr()[0]
                    if lr_scheduler else optimiser.param_groups[0]["lr"])
        learning_rates.append(epoch_lr)

        save_checkpoint(
            {
                "epoch": start_epoch + epoch + 1,
                "losses": losses,
                "accuracy": accuracy,
                "lr": learning_rates,
                "model": {
                    "encoder": enc.state_dict(),
                    "decoder": dec.state_dict()
                },
                "optimiser": optimiser.state_dict(),
            },
            prefix=prefix,
        )

        elapsed_time = time.time() - start_time
        elapsed_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        if epoch % print_epochs == 0 or epoch == num_epochs - 1:
            print("[{current:>{pad}}/{end}] Epoch {epoch}: "
                  "Accuracy = {accuracy:.5f}, "
                  "Loss = {loss:.5f}, "
                  "lr = {lr} "
                  "(time elapsed {time})".format(
                      current=epoch + 1,
                      end=num_epochs,
                      epoch=start_epoch + epoch + 1,
                      pad=len(str(num_epochs)),
                      accuracy=epoch_accuracy,
                      loss=mean_epoch_loss,
                      lr=epoch_lr,
                      time=elapsed_time,
                  ))
            write_tensorboard(writer, epoch, mean_epoch_loss, epoch_accuracy,
                              enc, dec)

    return np.array(losses), np.array(accuracy)
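
The inner decoding loop above is greedy: at each step the previous top-1 prediction is fed back in, the chosen symbol is appended to the running sequence, and the raw outputs are collected for the loss. A self-contained toy sketch of that pattern (the step function below is a stand-in, not the snippet's decoder):

import torch

def greedy_decode(step_fn, start_id, max_len, batch_size):
    """step_fn(previous_ids) -> logits of shape (batch, vocab)."""
    sequence = torch.full((batch_size, 1), start_id, dtype=torch.long)
    outputs = []
    for _ in range(max_len - 1):
        previous = sequence[:, -1].view(-1, 1)
        logits = step_fn(previous)
        _, top1_id = torch.topk(logits, 1)           # greedy choice
        sequence = torch.cat((sequence, top1_id), dim=1)
        outputs.append(logits)
    return sequence, torch.stack(outputs, dim=2)     # ids and (batch, vocab, max_len - 1)

# Toy step function: ignores its input and returns random logits over 5 symbols.
ids, logits = greedy_decode(lambda prev: torch.randn(prev.size(0), 5), 0, 4, 2)
print(ids.shape, logits.shape)  # torch.Size([2, 4]) torch.Size([2, 5, 3])
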