Example #1
def validate(val_loader, model, criterion):
    """ Evaluate model using criterion on validation set """

    losses = Meter(ptag='Loss')
    acc = Meter(ptag='Accuracy')
    # top1 = Meter(ptag='Prec@1')
    # top5 = Meter(ptag='Prec@5')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (features, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            # create one-hot vector from target
            kl_target = torch.zeros(target.shape[0], 1000,
                                    device='cuda').scatter_(
                                        1, target.view(-1, 1), 1)

            # compute output
            output = model(features)
            loss = criterion(output, kl_target)

            # measure accuracy and record loss
            # prec1, prec5 = accuracy(output, target, topk=())
            acc_val = accuracy(output, target)
            losses.update(loss.item(), features.size(0))
            acc.update(acc_val, features.size(0))
            # top1.update(prec1.item(), features.size(0))
            # top5.update(prec5.item(), features.size(0))

        log.info(' * Accuracy {acc.avg:.3f}'.format(acc=acc))

    return acc.avg
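The examples on this page lean on two helpers that are not shown here: a Meter that tracks running statistics and prints as a val,avg,std triplet in the CSV logs, and an accuracy function. The sketch below is only an assumption of what they might look like (the field names, the std handling, and the single-k return behaviour are guesses, not the project's actual implementation):

import math

import torch


class Meter(object):
    """Running statistics (latest value, mean, std) for a scalar series."""

    def __init__(self, init_dict=None, ptag=''):
        if init_dict is not None:
            # rebuild from a checkpointed Meter(...).__dict__
            self.__dict__.update(init_dict)
        else:
            self.ptag = ptag
            self.val = 0.0
            self.sum = 0.0
            self.sqsum = 0.0
            self.count = 0
            self.avg = 0.0
            self.std = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.sqsum += (val ** 2) * n
        self.count += n
        self.avg = self.sum / self.count
        self.std = math.sqrt(max(self.sqsum / self.count - self.avg ** 2, 0.0))

    def __str__(self):
        # matches the 'X,avg:X,std:X' columns written to the log files
        return '{:.3f},{:.3f},{:.3f}'.format(self.val, self.avg, self.std)


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of logits output against integer labels."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        # several call sites above request a single k and use the value directly
        return res[0] if len(res) == 1 else res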
Example #2
    def _setup_misc(self):
        # misc setup components that were in gossip_sgd
        config = self.config
        state = {}
        update_state(
            state, {
                'epoch': 0,
                'itr': 0,
                'best_prec1': 0,
                'is_best': True,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'elapsed_time': 0,
                'batch_meter': Meter(ptag='Time').__dict__,
                'data_meter': Meter(ptag='Data').__dict__,
                'nn_meter': Meter(ptag='Forward/Backward').__dict__
            })
        self.state = state

        # module used to relaunch jobs and handle external termination signals
        ClusterManager.set_checkpoint_dir(config['checkpoint_dir'])
        self.cmanager = ClusterManager(rank=config['rank'],
                                       world_size=config['world_size'],
                                       model_tag=config['tag'],
                                       state=state,
                                       all_workers=config['checkpoint_all'])

        # enable low-level optimization of compute graph using cuDNN library?
        cudnn.benchmark = True

        self.batch_meter = Meter(state['batch_meter'])
        self.data_meter = Meter(state['data_meter'])
        self.nn_meter = Meter(state['nn_meter'])

        # initialize log file
        if not os.path.exists(config['out_fname']):
            with open(config['out_fname'], 'w') as f:
                print('BEGIN-TRAINING\n'
                      'World-Size,{ws}\n'
                      'Num-DLWorkers,{nw}\n'
                      'Batch-Size,{bs}\n'
                      'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                      'NT(s),avg:NT(s),std:NT(s),'
                      'DT(s),avg:DT(s),std:DT(s),'
                      'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'.
                      format(ws=config['world_size'],
                             nw=config['num_dataloader_workers'],
                             bs=config['batch_size']),
                      file=f)

        self.start_itr = state['itr']
        self.start_epoch = state['epoch']
        self.elapsed_time = state['elapsed_time']
        self.begin_time = time.time() - state['elapsed_time']
        self.best_val_prec1 = 0
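The update_state helper is likewise not reproduced on this page. Since the same state dict is handed to ClusterManager and then refreshed after every epoch, it presumably mutates the dict in place rather than rebinding it; a minimal sketch under that assumption:

def update_state(state, update_dict):
    """Merge update_dict into the shared training-state dict in place."""
    for key, value in update_dict.items():
        state[key] = value
    return state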
Example #3
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time, num_itr_ignore):

    losses = Meter(ptag='Loss')
    acc = Meter(ptag="Accuracy")

    # top1 = Meter(ptag='Prec@1')
    # top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        target = target.cuda(non_blocking=True)

        # create one-hot vector from target
        # kl_target = torch.zeros(target.shape[0], 1000, device='cuda').scatter_(
        #     1, target.view(-1, 1), 1)

        if num_itr_ignore == 0:
            data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()

        if i % 100 == 0:
            update_learning_rate(optimizer,
                                 epoch,
                                 itr=i,
                                 itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        if not args.overlap and not args.all_reduce:
            log.debug('Transferring params')
            model.transfer_params()
        if num_itr_ignore == 0:
            nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        if num_itr_ignore == 0:
            batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        acc_val = accuracy(output, target)

        losses.update(loss.item(), batch.size(0))
        acc.update(acc_val, batch.size(0))

        # top1.update(prec1.item(), batch.size(0))
        # top5.update(prec5.item(), batch.size(0))
        if i % args.print_freq == 0:
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{acc.val:.3f},{acc.avg:.3f},'
                      '-1'.format(ep=epoch,
                                  itr=i,
                                  bt=batch_meter,
                                  dt=data_meter,
                                  nt=nn_meter,
                                  loss=losses,
                                  acc=acc),
                      file=f)
        if num_itr_ignore > 0:
            num_itr_ignore -= 1
        log_time = time.time() - log_time
        log.debug(log_time)

        if (args.num_iterations_per_training_epoch != -1
                and i + 1 == args.num_iterations_per_training_epoch):
            break

    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{acc.val:.3f},{acc.avg:.3f},'
              '-1'.format(ep=epoch,
                          itr=i,
                          bt=batch_meter,
                          dt=data_meter,
                          nt=nn_meter,
                          loss=losses,
                          acc=acc),
              file=f)
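The update_learning_rate call above is project-specific and not shown on this page. A schedule commonly used for this style of distributed ImageNet training, offered purely as an illustrative assumption (the base_lr, warmup_epochs and milestone values are placeholders, not the project's actual settings), is a linear warmup followed by step decay:

def update_learning_rate(optimizer, epoch, itr=None, itr_per_epoch=None,
                         base_lr=0.1, warmup_epochs=5, milestones=(30, 60, 80)):
    """Hypothetical warmup + step-decay schedule, applied in place."""
    if epoch < warmup_epochs and itr is not None and itr_per_epoch:
        # linear warmup from ~0 to base_lr over the first warmup_epochs
        progress = (epoch * itr_per_epoch + itr + 1) / float(warmup_epochs * itr_per_epoch)
        lr = base_lr * progress
    else:
        # divide the learning rate by 10 at each milestone epoch
        lr = base_lr * (0.1 ** sum(epoch >= m for m in milestones))
    for group in optimizer.param_groups:
        group['lr'] = lr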
Example #4
def main():

    global args, state, log
    args = parse_args()
    print("Successfully parsed the args")

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()
    print("Model has been initialized")

    if args.all_reduce:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = GossipDataParallel(model,
                                   graph=args.graph,
                                   mixing=args.mixing,
                                   comm_device=args.comm_device,
                                   push_sum=args.push_sum,
                                   overlap=args.overlap,
                                   synch_freq=args.synch_freq,
                                   verbose=args.verbose,
                                   use_streams=not args.no_cuda_streams)

    core_criterion = nn.CrossEntropyLoss()  # nn.KLDivLoss(reduction='batchmean').cuda()
    log_softmax = nn.LogSoftmax(dim=1)

    def criterion(input, kl_target):
        assert kl_target.dtype != torch.int64
        loss = core_criterion(log_softmax(input), kl_target)
        return loss

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not os.path.exists(args.out_fname):
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Accuracy,avg:Accuracy,val'.format(
                      ws=args.world_size,
                      nw=args.num_dataloader_workers,
                      bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    best_val_prec1 = 0
    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        if not args.all_reduce:
            model.block()

        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time, args.num_itr_ignore)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })
            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)

            if prec1 > best_val_prec1:
                update_state(state, {'is_best': True})
                best_val_prec1 = prec1

            epoch_id = epoch if not args.overwrite_checkpoints else None

            cmanager.save_checkpoint(
                epoch_id, requeue_on_signal=(epoch != args.num_epochs - 1))

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        acc = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(acc))

    log.info('elapsed_time {0}'.format(elapsed_time))
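Note that the criterion wrapper above passes log-probabilities into nn.CrossEntropyLoss, which applies a log-softmax of its own internally; the commented-out nn.KLDivLoss(reduction='batchmean') is the loss that actually expects log-probability inputs and a one-hot (or soft) target distribution. A small self-contained sketch of that pairing, assuming ImageNet's 1000 classes:

import torch
import torch.nn as nn

log_softmax = nn.LogSoftmax(dim=1)
kl_criterion = nn.KLDivLoss(reduction='batchmean')

logits = torch.randn(4, 1000)                  # model outputs for a batch of 4
target = torch.randint(0, 1000, (4,))          # integer class labels
kl_target = torch.zeros(4, 1000).scatter_(1, target.view(-1, 1), 1)

# for one-hot targets this equals the usual cross-entropy loss
loss = kl_criterion(log_softmax(logits), kl_target)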
Example #5
def validate(val_loader, model, criterion):
    """ Evaluate model using criterion on validation set """

    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to evaluate mode
    model.eval()

    model.disable_gossip()

    with torch.no_grad():
        for i, (features, target) in enumerate(val_loader):

            target = target.cuda(non_blocking=True)

            # compute output
            output = model(features)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), features.size(0))
            top1.update(prec1.item(), features.size(0))
            top5.update(prec5.item(), features.size(0))

        log.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg
Example #6
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time):

    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    model.enable_gossip()

    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):

        target = target.cuda(non_blocking=True)
        data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, target)

        bilat_freq = 100
        if i == 0:
            update_global_iteration_counter(itr=1, itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))
        elif (i + args.rank) % (bilat_freq) == 0:
            update_global_iteration_counter(itr=bilat_freq,
                                            itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))

        loss.backward()
        update_learning_rate(optimizer,
                             epoch,
                             itr=i,
                             itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
        if i % args.print_freq == 0:
            ep = args.global_epoch
            itr = args.global_itr % (len(loader) * args.world_size)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'.format(ep=ep,
                                                                itr=itr,
                                                                bt=batch_meter,
                                                                dt=data_meter,
                                                                nt=nn_meter,
                                                                loss=losses,
                                                                top1=top1,
                                                                top5=top5),
                      file=f)
        log_time = time.time() - log_time
        log.debug(log_time)

    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'.format(ep=epoch,
                                                        itr=i,
                                                        bt=batch_meter,
                                                        dt=data_meter,
                                                        nt=nn_meter,
                                                        loss=losses,
                                                        top1=top1,
                                                        top5=top5),
              file=f)
Example #7
def main():

    global args, state, log
    args = parse_args()

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()

    assert args.bilat and not args.all_reduce
    model = BilatGossipDataParallel(
        model,
        master_addr=args.master_addr,
        master_port=args.master_port,
        backend=args.backend,
        world_size=args.world_size,
        rank=args.rank,
        graph_class=args.graph_class,
        mixing_class=args.mixing_class,
        comm_device=args.comm_device,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        verbose=args.verbose,
        num_peers=args.ppi_schedule[0],
        network_interface_type=args.network_interface_type)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print(
                'BEGIN-TRAINING\n'
                'World-Size,{ws}\n'
                'Num-DLWorkers,{nw}\n'
                'Batch-Size,{bs}\n'
                'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                'NT(s),avg:NT(s),std:NT(s),'
                'DT(s),avg:DT(s),std:DT(s),'
                'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'.format(
                    ws=args.world_size,
                    nw=args.num_dataloader_workers,
                    bs=args.batch_size),
                file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    # start all agents' training loop at same time
    model.block()
    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    epoch = start_epoch
    stopping_criterion = epoch >= args.num_epochs
    while not stopping_criterion:

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })
            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)
            cmanager.save_checkpoint()
            # synchronize models at the end of the validation run
            model.block()

        epoch += 1
        stopping_criterion = args.global_epoch >= args.num_epochs

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    log.info('elapsed_time {0}'.format(elapsed_time))
Example #8
def validate(val_loader, model, criterion):
    """ Evaluate model using criterion on validation set """

    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (features, target) in enumerate(val_loader):

            # if args.fp16:
            #     features = features.cuda(non_blocking=True).half()
            # Not strictly needed here, but kept since it does no harm

            target = target.cuda(non_blocking=True)
            # create one-hot vector from target
            kl_target = torch.zeros(target.shape[0], 1000,
                                    device='cuda').scatter_(
                                        1, target.view(-1, 1), 1)

            # compute output
            output = model(features)
            loss = criterion(output, kl_target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), features.size(0))
            top1.update(prec1.item(), features.size(0))
            top5.update(prec5.item(), features.size(0))

        # log.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
        #          .format(top1=top1, top5=top5))
        log.info(
            ' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {losses.avg:.3f}'
            .format(top1=top1, top5=top5, losses=losses))

    return top1.avg
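The one-hot construction used in this validate (and in Examples #1 and #10) relies on scatter_ along dim=1. A tiny worked example with 5 classes instead of 1000 makes the shape handling concrete:

import torch

target = torch.tensor([2, 0, 4])                      # integer labels, shape (3,)
one_hot = torch.zeros(target.shape[0], 5).scatter_(   # start from a (3, 5) zero matrix
    1, target.view(-1, 1), 1)                         # write a 1 at each label's column
print(one_hot)
# tensor([[0., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.]])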
Example #9
def main():

    global args, state, log
    args = parse_args()

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.distributed:
        # initialize torch distributed backend
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        dist.init_process_group(backend=args.backend,
                                world_size=args.world_size,
                                rank=args.rank)

    # init model, loss, and optimizer
    model = init_model()
    if args.all_reduce:
        model = AllReduceDataParallel(model,
                                      distributed=args.distributed,
                                      comm_device=args.comm_device,
                                      verbose=args.verbose)
    else:
        if args.single_threaded:
            model = SimpleGossipDataParallel(model,
                                             distributed=args.distributed,
                                             graph=args.graph,
                                             comm_device=args.comm_device,
                                             push_sum=args.push_sum,
                                             verbose=args.verbose)
        else:
            model = GossipDataParallel(model,
                                       distributed=args.distributed,
                                       graph=args.graph,
                                       mixing=args.mixing,
                                       comm_device=args.comm_device,
                                       push_sum=args.push_sum,
                                       overlap=args.overlap,
                                       synch_freq=args.synch_freq,
                                       verbose=args.verbose)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(state, {
            'epoch': 0, 'itr': 0, 'best_prec1': 0, 'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
    })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              bs_fname=args.bs_fpath,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        f_fpath = cmanager.checkpoint_fpath
        if os.path.isfile(f_fpath):
            log.info("=> loading checkpoint '{}'"
                     .format(f_fpath))
            checkpoint = torch.load(f_fpath)
            update_state(state, {
                          'epoch': checkpoint['epoch'],
                          'itr': checkpoint['itr'],
                          'best_prec1': checkpoint['best_prec1'],
                          'is_best': False,
                          'state_dict': checkpoint['state_dict'],
                          'optimizer': checkpoint['optimizer'],
                          'elapsed_time': checkpoint['elapsed_time'],
                          'batch_meter': checkpoint['batch_meter'],
                          'data_meter': checkpoint['data_meter'],
                          'nn_meter': checkpoint['nn_meter']
            })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})"
                     .format(f_fpath,
                             checkpoint['epoch'], checkpoint['itr']))

            # synch models that are loaded
            if not args.overlap or args.all_reduce:
                model.transfer_params()

        else:
            log.info("=> no checkpoint found at '{}'"
                     .format(cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                  .format(ws=args.world_size,
                          nw=args.num_dataloader_workers,
                          bs=args.batch_size), file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        model.block()
        train(model, criterion, optimizer,
              batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(state, {
                'epoch': epoch + 1, 'itr': start_itr,
                'is_best': False,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'elapsed_time': elapsed_time,
                'batch_meter': batch_meter.__dict__,
                'data_meter': data_meter.__dict__,
                'nn_meter': nn_meter.__dict__
            })
            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'
                      .format(ep=epoch, itr=-1,
                              bt=batch_meter,
                              dt=data_meter, nt=nn_meter,
                              filler=-1, val=prec1), file=f)
            epoch_id = epoch if not args.overwrite_checkpoints else None
            cmanager.save_checkpoint(epoch_id)

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    cmanager.halt = True

    log.info('elapsed_time {0}'.format(elapsed_time))
Example #10
def train(config, model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time, num_itr_ignore, log):

    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        # if args.fp16:
        #     batch = batch.cuda(non_blocking=True).half()

        target = target.cuda(non_blocking=True)
        # create one-hot vector from target
        kl_target = torch.zeros(target.shape[0], 1000, device='cuda').scatter_(
            1, target.view(-1, 1), 1)

        if num_itr_ignore == 0:
            data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, kl_target)

        # if args.fp16:
        #     if args.amp:
        #         with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        #             scaled_loss.backward()
        #     else:
        #         optimizer.backward(loss)
        # else:
        #     loss.backward()

        loss.backward()

        if i % 100 == 0:
            update_learning_rate(config, optimizer, epoch, log, itr=i,
                                 itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        if not config['overlap'] and not config['all_reduce']:
            log.debug('Transferring params')
            model.transfer_params()
        if num_itr_ignore == 0:
            nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        if num_itr_ignore == 0:
            batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
        if i % config['print_freq'] == 0:
            with open(config['out_fname'], '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'
                      .format(ep=epoch, itr=i,
                              bt=batch_meter,
                              dt=data_meter, nt=nn_meter,
                              loss=losses, top1=top1,
                              top5=top5), file=f)
        if num_itr_ignore > 0:
            num_itr_ignore -= 1
        log_time = time.time() - log_time
        log.debug(log_time)

        if (config['num_iterations_per_training_epoch'] != -1 and
                i+1 == config['num_iterations_per_training_epoch']):
            break

    with open(config['out_fname'], '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'
              .format(ep=epoch, itr=i,
                      bt=batch_meter,
                      dt=data_meter, nt=nn_meter,
                      loss=losses, top1=top1,
                      top5=top5), file=f)

    return losses.avg, top1.avg, top5.avg
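Several of these examples fast-forward the loader's internal sample_iter so training can resume mid-epoch without materialising the skipped batches. That attribute is version-dependent: older PyTorch _DataLoaderIter objects exposed sample_iter, while newer _BaseDataLoaderIter objects name it _sampler_iter. A version-tolerant variant of the spoof loop (an assumption, not the project's code) might look like:

def fast_forward(loader_iter, num_batches):
    """Advance the underlying index sampler without fetching or collating data."""
    sampler_iter = getattr(loader_iter, 'sample_iter', None)
    if sampler_iter is None:
        sampler_iter = loader_iter._sampler_iter  # name used by newer PyTorch versions
    for _ in range(num_batches):
        next(sampler_iter)

# usage in place of the manual spoof loop above:
# fast_forward(_train_loader, itr)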