Example #1
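# Assumed imports for this example; the remaining helpers (parse_args,
# make_logger, update_state, init_model, Meter, ClusterManager,
# make_dataloader, update_peers_per_itr, train, validate) are
# project-specific and not shown in this excerpt.
import os
import socket
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

# mixed-precision utilities from NVIDIA apex (assumed)
from apex import amp
from apex.fp16_utils import FP16_Optimizer
from apex.parallel import DistributedDataParallel as ApexDDP

# gossip-based data-parallel wrapper (assumed to come from the
# project's gossip package)
from gossip import GossipDataParallel
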
def main():

    global amp_handle, args, state, log
    args = parse_args()
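    # initialize NVIDIA apex AMP (the old handle-based API) when both the
    # fp16 and amp flags are set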
    if args.fp16 and args.amp:
        amp_handle = amp.init()

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()
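    # wrap the model for distributed training: standard all-reduce DDP gives
    # synchronous SGD, while GossipDataParallel replaces the all-reduce with
    # gossip-based averaging (optionally push-sum over a directed graph)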
    if args.all_reduce:
        if args.fp16 and args.apex_ddp:
            model = ApexDDP(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = GossipDataParallel(model,
                                   graph=args.graph,
                                   mixing=args.mixing,
                                   comm_device=args.comm_device,
                                   push_sum=args.push_sum,
                                   overlap=args.overlap,
                                   synch_freq=args.synch_freq,
                                   verbose=args.verbose,
                                   use_streams=not args.no_cuda_streams)

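    # nn.KLDivLoss expects log-probabilities as input and a probability
    # distribution as target, so logits go through LogSoftmax in criterion()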
    core_criterion = nn.KLDivLoss(reduction='batchmean').cuda()
    log_softmax = nn.LogSoftmax(dim=1)

    def criterion(input, kl_target):
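        # the target must be a soft distribution, not int64 class indices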
        assert kl_target.dtype != torch.int64
        loss = core_criterion(log_softmax(input), kl_target)
        return loss

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
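    # without AMP, wrap the optimizer in apex's FP16_Optimizer, which keeps
    # FP32 master weights and applies dynamic loss scaling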
    if args.fp16 and not args.amp:
        optimizer = FP16_Optimizer(optimizer,
                                   dynamic_loss_scale=True,
                                   verbose=False)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable cuDNN benchmark mode to auto-tune convolution algorithms for the
    # fixed input size
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not os.path.exists(args.out_fname):
        with open(args.out_fname, 'w') as f:
            print(
                'BEGIN-TRAINING\n'
                'World-Size,{ws}\n'
                'Num-DLWorkers,{nw}\n'
                'Batch-Size,{bs}\n'
                'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                'NT(s),avg:NT(s),std:NT(s),'
                'DT(s),avg:DT(s),std:DT(s),'
                'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'.format(
                    ws=args.world_size,
                    nw=args.num_dataloader_workers,
                    bs=args.batch_size),
                file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

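    # restore bookkeeping so training resumes from the checkpointed position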
    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    best_val_prec1 = state['best_prec1']  # carry over best accuracy on resume
    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at the same time
        if not args.all_reduce:
            model.block()
        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time, args.num_itr_ignore)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })
            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, 'a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)

            if prec1 > best_val_prec1:
                update_state(state, {'is_best': True})
                best_val_prec1 = prec1

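            # a None epoch_id presumably makes ClusterManager overwrite a
            # single checkpoint file rather than keeping one per epoch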
            epoch_id = epoch if not args.overwrite_checkpoints else None

            cmanager.save_checkpoint(
                epoch_id, requeue_on_signal=(epoch != args.num_epochs - 1))

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    log.info('elapsed_time {0}'.format(elapsed_time))
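
# standard module entry point (assumed; not shown in the original excerpt)
if __name__ == '__main__':
    main()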