Example #1
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

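    # Treat the run as distributed when a local rank was assigned or more
    # than one worker process is requested.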
    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not path.exists(save_path):
            makedirs(save_path)
        export_args_namespace(args, path.join(save_path, 'config.json'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = save_path
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])

    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    else:
        optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

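    # Wrap the epoch-indexed regime in an OptimRegime unless the model already
    # provides one; use_float_copy keeps a float32 copy of the parameters when
    # training in half precision.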
    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      print_freq=args.print_freq,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      cutmix=args.cutmix,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      adapt_grad_norm=args.adapt_grad_norm)
    if args.tensorwatch:
        trainer.set_watcher(
            filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {
            'holes': 1,
            'length': 16
        } if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

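    # Track wall-clock time and epochs until validation prec@1 first exceeds
    # the 94% target checked at the end of each epoch.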
    start_time = time.time()
    end_time = None
    end_epoch = None
    found = False

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best,
            path=save_path,
            save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()

        if not found and val_results['prec1'] > 94:
            found = True
            end_time = time.time() - start_time
            end_epoch = epoch + 1

    if not found:
        end_time = time.time() - start_time
        end_epoch = epoch + 1

    print("Target reached: {}, minutes: {}, epochs: {}".format(
        found, round(end_time / 60, 3), end_epoch))
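
A minimal, self-contained sketch of the checkpoint round-trip this example relies on (save, then resume with map_location="cpu"). The tiny model and file name below are placeholders, not part of the original code.

import torch
import torch.nn as nn

# Placeholder model standing in for models.__dict__[args.model](**model_config)
model = nn.Linear(10, 2)

# Save: mirror the fields save_checkpoint() writes above.
torch.save({
    'epoch': 1,
    'state_dict': model.state_dict(),
    'best_prec1': 0.0,
}, 'model_best.pth.tar')

# Resume: load on CPU first, then restore weights and bookkeeping.
checkpoint = torch.load('model_best.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']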
Example #2
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    model_config = {'dataset': args.dataset, 'batch': args.batch_size}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # auto-generate the run name from the arguments and model config
    fname = auto_name(args, model_config)
    args.save = fname
    monitor = args.monitor
    print(fname)

    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not args.dry:
            if not path.exists(save_path):
                makedirs(save_path)
            export_args_namespace(args, path.join(save_path, 'config.json'))
    if monitor > 0 and not args.dry:
        events_path = "runs/%s" % fname
        my_file = Path(events_path)
        if my_file.is_file():
            os.remove(events_path)

        writer = SummaryWriter(log_dir=events_path, comment=str(args))
        model_config['writer'] = writer
        model_config['monitor'] = monitor
    else:
        monitor = 0
        writer = None

    if args.dry:
        model = models.__dict__[args.model]
        model = model(**model_config)
        print("created model with configuration: %s" % model_config)
        num_parameters = sum(p.nelement() for p in model.parameters())
        print("number of parameters: %d" % num_parameters)
        return

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model = model(**model_config)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(
                checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    else:
        optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype, print_freq=args.print_freq,
                      distributed=args.distributed, local_rank=args.local_rank, mixup=args.mixup, cutmix=args.cutmix,
                      loss_scale=args.loss_scale, grad_clip=args.grad_clip, adapt_grad_norm=args.adapt_grad_norm,
                      writer=writer, monitor=monitor)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val', 'augment': False,
                                    'input_size': args.input_size, 'batch_size': args.eval_batch_size, 'shuffle': False,
                                    'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})
    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train', 'augment': True,
                           'input_size': args.input_size,  'batch_size': args.batch_size, 'shuffle': True,
                           'num_workers': args.workers, 'pin_memory': True, 'drop_last': True,
                           'distributed': args.distributed, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment,
                           'cutout': {'holes': 1, 'length': 16} if args.cutout else None}

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(
            getattr(model, 'data_regime', None), defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

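    # Optional probe: if args.covmat names a layer (by name, or by integer
    # index where a negative value counts back from the last child), attach a
    # forward hook reporting covariance statistics of that layer's output.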
    if args.covmat != "":
        try:
            int_covmat = int(args.covmat)
            if int_covmat < 0:
                # negative index: count children from the end
                total_layers = len([name for name, layer in model.named_children()])
                int_covmat = total_layers + int_covmat
            child_cnt = 0
        except ValueError:
            int_covmat = None  # args.covmat is a layer name, not an index

        def calc_covmat(x_, partitions=64):
            # Split x_ into `partitions` row-blocks and collect the entries of
            # the Gram matrix x_ @ x_.T, separating diagonal (self) products
            # from off-diagonal (cross) products.
            L = x_.shape[0] // partitions

            non_diags = []
            diags = []
            for p1 in range(partitions):
                for p2 in range(partitions):
                    x = x_[p1 * L:(p1 + 1) * L]
                    y = x_[p2 * L:(p2 + 1) * L]
                    X = torch.matmul(x, y.transpose(0, 1))
                    if p1 == p2:
                        mask = torch.eye(X.shape[0], dtype=torch.bool)
                        non_diag = X[~mask].reshape(-1).cpu()
                        diag = X[mask].reshape(-1).cpu()
                        non_diags.append(non_diag)
                        diags.append(diag)
                    else:
                        non_diag = X.reshape(-1).cpu()
                        non_diags.append(non_diag)
            diags = torch.cat(diags)
            non_diags = torch.cat(non_diags)

            diag_var = diags.var()
            non_diag_var = non_diags.var()

            # fraction of entries more than one std below the mean
            diags = diags - diags.mean()
            non_diags = non_diags - non_diags.mean()
            diag_small_ratio = (diags < -diags.std()).to(dtype=torch.float).mean()
            non_diag_small_ratio = (non_diags < -non_diags.std()).to(dtype=torch.float).mean()

            return diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio


        global diag_var_mean
        global non_diag_var_mean
        global var_count
        var_count = 0
        diag_var_mean = 0
        non_diag_var_mean = 0

        def report_covmat_hook(module, input, output):
            # Forward hook: update running means of the covariance statistics
            # and print a summary every 10 calls.
            global diag_var_mean
            global non_diag_var_mean
            global var_count

            flatten_output = output.reshape([-1, 1]).detach()
            diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio = calc_covmat(flatten_output)

            diag_var_mean = diag_var_mean + diag_var
            non_diag_var_mean = non_diag_var_mean + non_diag_var
            var_count = var_count + 1
            if var_count % 10 == 1:
                print("diag_var = %.2f (%.2f), ratio: %.2f, non_diag_var = %.2f (%.2f), ratio: %.2f" %
                      (diag_var, diag_var_mean / var_count, diag_small_ratio,
                       non_diag_var, non_diag_var_mean / var_count, non_diag_small_ratio))

        for name, layer in model.named_children():
            if int_covmat is None:
                condition = (name == args.covmat)
            else:
                condition = (child_cnt == int_covmat)
                child_cnt = child_cnt + 1
            if condition:
                layer.register_forward_hook(report_covmat_hook)

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        if writer is not None:
            writer.add_scalar('Train/Loss', train_results['loss'], epoch)
            writer.add_scalar('Train/Prec@1', train_results['prec1'], epoch)
            writer.add_scalar('Train/Prec@5', train_results['prec5'], epoch)
            writer.add_scalar('Val/Loss', val_results['loss'], epoch)
            writer.add_scalar('Val/Prec@1', val_results['prec1'], epoch)
            writer.add_scalar('Val/Prec@5', val_results['prec5'], epoch)
            # tmplr = optimizer.get_lr()
            # writer.add_scalar('HyperParameters/learning-rate',  tmplr, epoch)

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
    logging.info(f'\nBest Validation Accuracy (top1): {best_prec1}')

    if writer:
        writer.close()
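
A minimal standalone sketch of the forward-hook pattern Example #2 uses to probe layer outputs; the two-layer model and the printed statistic are illustrative placeholders.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def stats_hook(module, inputs, output):
    # Called after every forward pass of the hooked module.
    flat = output.detach().reshape(-1)
    print("%s: mean=%.3f var=%.3f" %
          (module.__class__.__name__, flat.mean().item(), flat.var().item()))

# Hook the first child, mirroring the named_children() selection above.
handle = list(model.children())[0].register_forward_hook(stats_hook)

model(torch.randn(32, 8))  # forward pass triggers the hook
handle.remove()            # detach the hook when done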