Example 1
def train(params, args, world_rank, local_rank):

    #logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))
    #train_data_loader = get_data_loader_distributed(params, world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank,
                                                    local_rank)
    logging.info('rank %d, data loader initialized' % world_rank)

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    model = UNet.UNet(params)
    model.to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # amp stuff
    #gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.

        model.train()
        step_time = time.time()
        for i, data in enumerate(train_data_loader, 0):
            iters += 1

            #adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)

            if not args.io_only:

                # fw pass
                fw_time -= time.time()
                optimizer.zero_grad()
                #with amp.autocast():
                gen = model(inp)
                loss = UNet.loss_func(gen, tar, params)
                fw_time += time.time()

                # bw pass
                bw_time -= time.time()
                loss.backward()
                optimizer.step()
                #gscaler.scale(loss).backward()
                #gscaler.step(optimizer)
                #gscaler.update()
                bw_time += time.time()

            nsteps += 1

        # epoch done
        dist.barrier()
        step_time = (time.time() - step_time) / float(nsteps)
        fw_time /= float(nsteps)
        bw_time /= float(nsteps)
        io_time = max([step_time - fw_time - bw_time, 0])
        iters_per_sec = 1. / step_time

        ## Output training stats
        #model.eval()
        #if world_rank==0:
        #  log_start = time.time()
        #  gens = []
        #  tars = []
        #  with torch.no_grad():
        #    for i, data in enumerate(train_data_loader, 0):
        #      if i>=16:
        #        break
        #      #inp, tar = map(lambda x: x.to(device), data)
        #      inp, tar = data
        #      gen = model(inp)
        #      gens.append(gen.detach().cpu().numpy())
        #      tars.append(tar.detach().cpu().numpy())
        #  gens = np.concatenate(gens, axis=0)
        #  tars = np.concatenate(tars, axis=0)
        #
        #  # Scalars
        #  args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #  # Plots
        #  fig, chi, L1score = meanL1(gens, tars)
        #  args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #  args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #  args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #  args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #  args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #  args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #  args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #  fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #  args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #  log_end = time.time()
        #  log_time += log_end - log_start

        #  # Save checkpoint
        #  torch.save({'iters': iters, 'epoch':epoch, 'model_state': model.state_dict(),
        #              'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info(
                'total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
                .format(step_time, fw_time, bw_time, io_time, iters_per_sec,
                        log_time))

        # finalize
        dist.barrier()
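
The commented-out `amp` and `gscaler` lines in Example 1 correspond to PyTorch's native automatic mixed precision. A minimal sketch of what that inner training step would look like with those lines enabled, assuming the same `model`, `optimizer`, `UNet.loss_func`, and `params` names used above (the helper name `amp_step` is hypothetical):

from torch.cuda import amp

gscaler = amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow

def amp_step(model, optimizer, inp, tar, params):
    optimizer.zero_grad()
    # run the forward pass and loss in mixed precision
    with amp.autocast():
        gen = model(inp)
        loss = UNet.loss_func(gen, tar, params)
    # scale the loss, backprop, take an unscaled optimizer step, then update the scale factor
    gscaler.scale(loss).backward()
    gscaler.step(optimizer)
    gscaler.update()
    return loss

The timing bookkeeping (fw_time, bw_time) would wrap the autocast block and the three gscaler calls in the same way the eager-mode code above wraps the forward and backward passes.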
Example 2
def train(params, args, world_rank):
    logging.info('rank %d, begin data loader init' % world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank)
    test_data_loader = get_data_loader_distributed_test(params, world_rank)
    logging.info('rank %d, data loader initialized' % world_rank)
    model = UNet.UNet(params).cuda()
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DistributedDataParallel(model)

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path,
                                map_location='cuda:{}'.format(args.local_rank))
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    device = torch.cuda.current_device()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        tr_time = 0.
        log_time = 0.

        for i, data in enumerate(train_data_loader, 0):
            iters += 1
            adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)
            tr_start = time.time()
            b_size = inp.size(0)

            model.zero_grad()
            gen = model(inp)
            loss = UNet.loss_func(gen, tar, params)

            loss.backward()  # full precision (no mixed-precision loss scaling)

            # automatic mixed precision:
            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #  scaled_loss.backward()

            optimizer.step()

            tr_end = time.time()
            tr_time += tr_end - tr_start

        # Output training stats
        if world_rank == 0:
            log_start = time.time()
            gens = []
            tars = []
            with torch.no_grad():
                for i, data in enumerate(test_data_loader, 0):
                    if i >= 50:
                        break
                    inp, tar = map(lambda x: x.to(device), data)
                    gen = model(inp)
                    gens.append(gen.detach().cpu().numpy())
                    tars.append(tar.detach().cpu().numpy())
            gens = np.concatenate(gens, axis=0)
            tars = np.concatenate(tars, axis=0)

            # Scalars
            args.tboard_writer.add_scalar('G_loss', loss.item(), iters)

            # Plots
            fig = plot_gens_tars(gens, tars)
            #fig, chi, L1score = meanL1(gens, tars)
            #args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
            #args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
            #args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
            #args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
            #args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
            #args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
            #args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
            #
            #fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
            for figiter in range(5):
                figtag = 'test' + str(figiter)
                args.tboard_writer.add_figure(tag=figtag,
                                              figure=fig[figiter],
                                              close=True)
            log_end = time.time()
            log_time += log_end - log_start

            # Save checkpoint
            torch.save(
                {
                    'iters': iters,
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info('train step time={}, logging time={}'.format(
                tr_time, log_time))