import sys

import torch
from torch.utils.tensorboard import SummaryWriter

# Assumed available from this repository's own modules: parse_args, init_model,
# Serializer, FakeTimer, construct_train_tools, init_losses, create_hooks,
# get_dataloader, get_trainset_params, get_resolution, Profiler, GPUMonitor,
# and train. SynchronizedWallClockTimer matches the helper shipped with
# DeepSpeed (deepspeed.utils.timer).


def main():
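    # enable autograd anomaly detection to trace NaN gradients (very slow):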
    # torch.autograd.set_detect_anomaly(True)

    args = parse_args(sys.argv[1:])

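    # Bind the CUDA context only when a GPU device was actually requested.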
    device = torch.device(args.device)
    if device.type == 'cuda':
        torch.cuda.set_device(device)
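    # Synchronized wall-clock timers cost GPU syncs; default to a no-op stub.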
    if args.timers:
        timers = SynchronizedWallClockTimer()
    else:
        timers = FakeTimer()

    model = init_model(args, device)

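    # Checkpoint manager rooted at the model directory.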
    serializer = Serializer(args.model, args.num_checkpoints,
                            args.permanent_interval)

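    # Force a fresh start when resuming is disabled or no checkpoint exists yet.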
    args.do_not_continue = (args.do_not_continue
                            or len(serializer.list_known_steps()) == 0)
    last_step = (0 if args.do_not_continue else
                 serializer.list_known_steps()[-1])

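    # Build the optimizer and scheduler, fast-forwarded past completed steps.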
    optimizer, scheduler = construct_train_tools(args,
                                                 model,
                                                 passed_steps=last_step)

    losses = init_losses(args.shape,
                         args.bs,
                         model,
                         device,
                         sequence_length=(args.prefix_length +
                                          args.suffix_length + 1),
                         timers=timers)

    # allow only manual flush
    logger = SummaryWriter(str(args.log_path),
                           max_queue=100000000,
                           flush_secs=100000000)

    periodic_hooks, hooks = create_hooks(args, model, optimizer, losses,
                                         logger, serializer)

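    # Either restore model/optimizer state from the last checkpoint, or
    # persist the fresh initialization as the step-0 checkpoint.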
    if not args.do_not_continue:
        global_step, state = serializer.load_checkpoint(model,
                                                        last_step,
                                                        optimizer=optimizer,
                                                        device=device)
        samples_passed = state.pop('samples_passed', global_step * args.bs)
    else:
        global_step = 0
        samples_passed = 0
        hooks['serialization'](global_step, samples_passed)

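    # sample_idx skips ahead so a resumed run does not revisit seen data.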
    loader = get_dataloader(get_trainset_params(args),
                            sample_idx=samples_passed,
                            process_only_once=False)

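    # Validate the starting weights so later metrics have a baseline.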
    if not args.skip_validation:
        hooks['validation'](global_step, samples_passed)

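    # Train under the profiler and GPU monitor contexts.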
    with Profiler(args.profiling, args.model / 'profiling'), \
            GPUMonitor(args.log_path):
        train(model,
              device,
              loader,
              optimizer,
              args.training_steps,
              scheduler=scheduler,
              evaluator=losses,
              logger=logger,
              weights=args.loss_weights,
              is_raw=args.is_raw,
              accumulation_steps=args.accum_step,
              timers=timers,
              hooks=periodic_hooks,
              init_step=global_step,
              init_samples_passed=samples_passed,
              max_events_per_batch=args.max_events_per_batch)

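    # Samples processed after the remaining (training_steps - global_step)
    # steps finish.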
    samples = samples_passed + (args.training_steps - global_step) * args.bs
    hooks['serialization'](args.training_steps, samples)
    if not args.skip_validation:
        hooks['validation'](args.training_steps, samples)
Example #2
def main():
    # torch.autograd.set_detect_anomaly(True)

    args = parse_args()

    device = torch.device(args.device)
    if device.type == 'cuda':  # guard: set_device fails on CPU-only runs
        torch.cuda.set_device(device)
    if args.timers:
        timers = SynchronizedWallClockTimer()
    else:
        timers = FakeTimer()

    model = init_model(args, device)

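    # This variant builds the loader up front and does not fast-forward
    # past previously seen samples.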
    loader = get_dataloader(get_trainset_params(args))

    serializer = Serializer(args.model, args.num_checkpoints,
                            args.permanent_interval)

    args.do_not_continue = (args.do_not_continue
                            or len(serializer.list_known_steps()) == 0)
    last_step = (0 if args.do_not_continue else
                 serializer.list_known_steps()[-1])

    optimizer, scheduler = construct_train_tools(args,
                                                 model,
                                                 passed_steps=last_step)

    losses = init_losses(get_resolution(args),
                         args.bs,
                         model,
                         device,
                         timers=timers)

    logger = SummaryWriter(str(args.log_path))

    periodic_hooks, hooks = create_hooks(args, model, optimizer, losses,
                                         logger, serializer)

    if not args.do_not_continue:
        global_step, state = serializer.load_checkpoint(model,
                                                        last_step,
                                                        optimizer=optimizer,
                                                        device=device)
        samples_passed = state.pop('samples_passed', global_step * args.bs)
    else:
        global_step = 0
        samples_passed = 0
        hooks['serialization'](global_step, samples_passed)

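    # Initial validation is unconditional here (no skip_validation flag).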
    hooks['validation'](global_step, samples_passed)

    with Profiler(args.profiling, args.model / 'profiling'):
        train(model,
              device,
              loader,
              optimizer,
              args.training_steps,
              scheduler=scheduler,
              evaluator=losses,
              logger=logger,
              weights=args.loss_weights,
              is_raw=args.is_raw,
              accumulation_steps=args.accum_step,
              timers=timers,
              hooks=periodic_hooks,
              init_step=global_step,
              init_samples_passed=samples_passed)

    samples = samples_passed + (args.training_steps - global_step) * args.bs
    hooks['serialization'](args.training_steps, samples)
    hooks['validation'](args.training_steps, samples)