Example #1
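Both examples assume a common set of imports that the page does not show. The sketch below is a best guess at those imports: `DALIGenericIterator` comes from NVIDIA DALI and `FusedAdam` from NVIDIA Apex, `amp` refers to `torch.cuda.amp` (the Apex `amp.initialize` call appears only commented out), and `dl` and `UNet` are project-local modules whose import paths here are placeholders.

# Assumed imports for both examples (dl and UNet are project-local
# modules; their import paths here are placeholders).
import logging
import time

import torch
import torch.distributed as dist
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import optimizers  # NVIDIA Apex, provides FusedAdam
from nvidia.dali.plugin.pytorch import DALIGenericIterator

import UNet                  # project-local: model + loss definitions
import data_loader as dl     # project-local: DaliPipeline definition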
def train(params, args, world_rank, local_rank):

    # log data loader startup on this rank
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank {:d}, data loader initialized'.format(world_rank))

    model = UNet.UNet(params)
    model.to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # gradient scaler for automatic mixed precision
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.

        model.train()
        step_time = time.time()
        with torch.autograd.profiler.emit_nvtx():
            for data in train_data_loader:
                iters += 1

                #adjust_LR(optimizer, params, iters)
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    torch.cuda.nvtx.range_push("cosmo3D:forward")
                    # fw pass
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(enabled=args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push("cosmo3D:backward")
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                nsteps += 1

            # epoch done: synchronize ranks so the timings below agree across ranks
            if params.distributed:
                dist.barrier()
            step_time = (time.time() - step_time) / float(nsteps)
            fw_time /= float(nsteps)
            bw_time /= float(nsteps)
            # time not accounted for by fw/bw is exposed I/O time
            io_time = max(step_time - fw_time - bw_time, 0.)
            iters_per_sec = 1. / step_time

            end = time.time()
            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, end - start))
                logging.info(
                    'total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
                    .format(step_time, fw_time, bw_time, io_time,
                            iters_per_sec, log_time))

        ## Output training stats
        #model.eval()
        #if world_rank==0:
        #  log_start = time.time()
        #  gens = []
        #  tars = []
        #  with torch.no_grad():
        #    for i, data in enumerate(train_data_loader, 0):
        #      if i>=16:
        #        break
        #      #inp, tar = map(lambda x: x.to(device), data)
        #      inp, tar = data
        #      gen = model(inp)
        #      gens.append(gen.detach().cpu().numpy())
        #      tars.append(tar.detach().cpu().numpy())
        #  gens = np.concatenate(gens, axis=0)
        #  tars = np.concatenate(tars, axis=0)
        #
        #  # Scalars
        #  args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #  # Plots
        #  fig, chi, L1score = meanL1(gens, tars)
        #  args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #  args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #  args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #  args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #  args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #  args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #  args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #  fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #  args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #  log_end = time.time()
        #  log_time += log_end - log_start

        #  # Save checkpoint
        #  torch.save({'iters': iters, 'epoch':epoch, 'model_state': model.state_dict(),
        #              'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)


        # finalize: synchronize ranks before the next epoch
        if params.distributed:
            dist.barrier()
Example #2
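Unlike Example #1, this variant activates the `adjust_LR` call on every step; its definition is not shown on this page. Below is a minimal sketch of a step-based schedule with the same signature. The warmup shape and the `params.warmup_iters` field are assumptions, not the original schedule.

def adjust_LR(optimizer, params, iters):
    # Hypothetical stand-in for the adjust_LR helper used below; the
    # original schedule is not shown. Assumes a params.warmup_iters field.
    if iters < params.warmup_iters:
        # linear warmup from ~0 up to the base learning rate
        lr = params.lr * float(iters + 1) / float(params.warmup_iters)
    else:
        # inverse square-root decay after warmup
        lr = params.lr * (float(params.warmup_iters) / float(iters)) ** 0.5
    for group in optimizer.param_groups:
        group['lr'] = lr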
def train(params, args, world_rank, local_rank):

    # log data loader startup on this rank
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank {:d}, data loader initialized'.format(world_rank))

    model = UNet.UNet(params).to(device)

    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model,
                    device_ids=[device.index],
                    output_device=device.index)

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # gradient scaler for automatic mixed precision
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    with torch.autograd.profiler.emit_nvtx():
        for epoch in range(startEpoch, startEpoch + params.num_epochs):

            if args.global_timing:
                dist.barrier()

            start = time.time()
            epoch_step = 0
            tr_time = 0.
            fw_time = 0.
            bw_time = 0.
            log_time = 0.

            model.train()
            for data in train_data_loader:
                torch.cuda.nvtx.range_push("cosmo3D:step {}".format(iters))
                tr_start = time.time()
                adjust_LR(optimizer, params, iters)

                # fetch data
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:forward {}".format(iters))
                    # fw pass
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(enabled=args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:backward {}".format(iters))
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                iters += 1
                epoch_step += 1

                # step done
                tr_end = time.time()
                tr_time += tr_end - tr_start
                torch.cuda.nvtx.range_pop()

            # epoch done
            if args.global_timing:
                dist.barrier()

            end = time.time()
            epoch_time = end - start
            step_time = epoch_time / float(epoch_step)
            tr_time /= float(epoch_step)
            fw_time /= float(epoch_step)
            bw_time /= float(epoch_step)
            # time not accounted for by fw/bw is exposed I/O time
            io_time = max(step_time - fw_time - bw_time, 0.)
            iters_per_sec = 1. / step_time
            fw_per_sec = 1. / tr_time

            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, epoch_time))
                logging.info(
                    'train step time = {} ({} steps), logging time = {}'.
                    format(tr_time, epoch_step, log_time))
                logging.info('train iters/sec = {} compute steps/sec = {}'.format(
                    iters_per_sec, fw_per_sec))
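Neither example shows how `train` is launched. Under the usual one-process-per-GPU DDP pattern the caller initializes the process group and derives the ranks before the call; here is a minimal sketch assuming a `torchrun`-style launcher sets the environment variables (the `main` wrapper and its flag handling are assumptions):

# Hypothetical launcher sketch, e.g. started via:
#   torchrun --nproc_per_node=<num_gpus> train.py
# RANK / LOCAL_RANK / WORLD_SIZE are set by the launcher.
import os

def main(params, args):
    world_rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    params.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if params.distributed:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl')
    train(params, args, world_rank, local_rank)
    if params.distributed:
        dist.destroy_process_group()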