Example #1
def main(_):
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    if strategy:
        print('***** Number of cores used: ', strategy.num_replicas_in_sync)

    if FLAGS.use_horovod:
        if strategy:
            raise ValueError(
                'Should not run horovod with distribution strategy')

        hvd.init()
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                       'GPU')
            gpu_affinity.set_affinity(hvd.local_rank())

    if FLAGS.use_fp16:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
        tf.keras.mixed_precision.experimental.set_policy(policy)

    run_bert_pretrain(strategy)
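
These `main(_)` entry points follow the absl convention: the positional argument list is ignored and configuration is read from module-level `FLAGS`. A minimal launcher sketch under that assumption; the flag definitions below are illustrative only, the real script defines its own flag set elsewhere:

from absl import app, flags

# Illustrative flag definitions; not the script's actual flag module.
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and summaries.')
flags.DEFINE_string('distribution_strategy', 'mirrored', 'Distribution strategy to use.')
flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs to use.')
flags.DEFINE_string('tpu', None, 'TPU address, if any.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)  # parses sys.argv into FLAGS, then calls main(remaining_argv)
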
Example #2
def main(_):
    # Users should always run this script under TF 2.x
    # The container hasn't changed the version number yet, so skip the check.
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    if FLAGS.use_horovod:
        if strategy:
            raise ValueError(
                'Should not run horovod with distribution strategy')

        hvd.init()
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                       'GPU')
            gpu_affinity.set_affinity(hvd.local_rank())

    if FLAGS.use_fp16:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
        tf.keras.mixed_precision.experimental.set_policy(policy)

    dllogging = dllogger_class.dllogger_class(FLAGS.dllog_path)
    input_meta_data['dllogging'] = dllogging

    if FLAGS.mode in ('train', 'train_and_predict'):
        train_squad(strategy, input_meta_data)
    if FLAGS.mode in ('predict',
                      'train_and_predict') and (not FLAGS.use_horovod
                                                or hvd.rank() == 0):
        predict_squad(strategy, input_meta_data)
Example #3
def main():
    args = parse_args()

    hvd.init()
    set_affinity(hvd.local_rank())

    if is_main_process():
        log("Running total processes: {}".format(get_world_size()))
    log("Starting process: {}".format(get_rank()))

    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    tf.random.set_seed(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})
    # script parameters
    BATCH_SIZE = args.train_batch_size
    EVAL_BATCH_SIZE = args.predict_batch_size
    USE_XLA = args.xla
    USE_AMP = args.amp
    EPOCHS = args.num_train_epochs

    if not args.do_train:
        EPOCHS = args.num_train_epochs = 1
        log("Since running inference only, setting args.num_train_epochs to 1")

    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    # TensorFlow configuration
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    tf.config.optimizer.set_jit(USE_XLA)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
    
    if args.amp:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' % policy.variable_dtype)  # Variable dtype: float32

    if is_main_process():
        log("***** Loading tokenizer and model *****")
    # Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
    electra_model = args.electra_model
    config = ElectraConfig.from_pretrained(electra_model, cache_dir=args.cache_dir)
    config.update({"amp": args.amp})
    if args.vocab_file is None:
        tokenizer = ElectraTokenizer.from_pretrained(electra_model, cache_dir=args.cache_dir)
    else:
        tokenizer = ElectraTokenizer(
            vocab_file=args.vocab_file,
            do_lower_case=args.do_lower_case)

    model = TFElectraForQuestionAnswering.from_pretrained(electra_model, config=config, cache_dir=args.cache_dir, args=args)

    if is_main_process():
        log("***** Loading dataset *****")
    # Load data
    processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
    train_examples = processor.get_train_examples(args.data_dir) if args.do_train else None
    dev_examples = processor.get_dev_examples(args.data_dir) if args.do_predict else None

    if is_main_process():
        log("***** Loading features *****")
    # Load cached features
    squad_version = '2.0' if args.version_2_with_negative else '1.1'
    if args.cache_dir is None:
        args.cache_dir = args.data_dir
    cached_train_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_train-v{4}.json_{1}_{2}_{3}'.format(
        electra_model.split("/")[1], str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)
    cached_dev_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_dev-v{4}.json_{1}_{2}_{3}'.format(
        electra_model.split("/")[1], str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)

    try:
        with open(cached_train_features_file, "rb") as reader:
            train_features = pickle.load(reader) if args.do_train else []
        with open(cached_dev_features_file, "rb") as reader:
            dev_features = pickle.load(reader) if args.do_predict else []
    except Exception:  # fall back to rebuilding features if the cache cannot be read
        train_features = (  # TODO: (yy) do on rank 0?
            squad_convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True,
                return_dataset="",
            )
            if args.do_train
            else []
        )
        dev_features = (
            squad_convert_examples_to_features(
                examples=dev_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=False,
                return_dataset="",
            )
            if args.do_predict
            else []
        )
        # Dump Cached features
        if not args.skip_cache and is_main_process():
            if args.do_train:
                log("***** Building Cache Files: {} *****".format(cached_train_features_file))
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
            if args.do_predict:
                log("***** Building Cache Files: {} *****".format(cached_dev_features_file))
                with open(cached_dev_features_file, "wb") as writer:
                    pickle.dump(dev_features, writer)

    len_train_features = len(train_features)
    total_train_steps = int((len_train_features * EPOCHS / BATCH_SIZE) / get_world_size()) + 1
    train_steps_per_epoch = int((len_train_features / BATCH_SIZE) / get_world_size()) + 1
    len_dev_features = len(dev_features)
    total_dev_steps = int((len_dev_features / EVAL_BATCH_SIZE)) + 1

    train_dataset = get_dataset_from_features(train_features, BATCH_SIZE,
                                              v2=args.version_2_with_negative) if args.do_train else []
    dev_dataset = get_dataset_from_features(dev_features, EVAL_BATCH_SIZE, drop_remainder=False, ngpu=1, mode="dev",
                                            v2=args.version_2_with_negative) if args.do_predict else []

    opt = create_optimizer(init_lr=args.learning_rate, num_train_steps=total_train_steps,
                           num_warmup_steps=int(args.warmup_proportion * total_train_steps),
                           weight_decay_rate=args.weight_decay_rate,
                           layerwise_lr_decay=args.layerwise_lr_decay,
                           n_transformer_layers=model.num_hidden_layers)
    if USE_AMP:
        # loss scaling is currently required when using mixed precision
        opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

    # Define loss function
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    loss_class = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        name='binary_crossentropy'
    )
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
    model.compile(optimizer=opt, loss=loss, metrics=[metric])
    train_loss_results = []

    if args.do_train and is_main_process():
        log("***** Running training *****")
        log("  Num examples = ", len_train_features)
        log("  Num Epochs = ", args.num_train_epochs)
        log("  Instantaneous batch size per GPU = ", args.train_batch_size)
        log(
            "  Total train batch size (w. parallel, distributed & accumulation) = ",
            args.train_batch_size
            * get_world_size(),
        )
        log("  Total optimization steps =", total_train_steps)

    total_train_time = 0
    latency = []
    for epoch in range(EPOCHS):
        if args.do_train:
            epoch_loss_avg = tf.keras.metrics.Mean()
            epoch_perf_avg = tf.keras.metrics.Mean()
            epoch_start = time.time()

            epoch_iterator = tqdm(train_dataset, total=train_steps_per_epoch, desc="Iteration", mininterval=5,
                                  disable=not is_main_process())
            for iter, inputs in enumerate(epoch_iterator):
                # stop early once max_steps is reached (only when max_steps > 0)
                if args.max_steps > 0 and (epoch * train_steps_per_epoch + iter) > args.max_steps:
                    break
                iter_start = time.time()
                # Optimize the model
                loss_value = train_step(model, inputs, loss, USE_AMP, opt, (iter == 0 and epoch == 0),
                                        v2=args.version_2_with_negative, loss_class=loss_class, fp16=USE_AMP)
                epoch_perf_avg.update_state(1. * BATCH_SIZE / (time.time() - iter_start))
                if iter % args.log_freq == 0:
                    if is_main_process():
                        log("\nEpoch: {:03d}, Step:{:6d}, Loss:{:12.8f}, Perf:{:5.0f}, loss_scale:{}, opt_step:{}".format(epoch, iter, loss_value,
                                                                                              epoch_perf_avg.result() * get_world_size(), opt.loss_scale if config.amp else 1,
                                                                                              int(opt.iterations)))
                    dllogger.log(step=(epoch, iter,), data={"step_loss": float(loss_value.numpy()),
                                                            "train_perf": float( epoch_perf_avg.result().numpy() * get_world_size())})

                # Track progress
                epoch_loss_avg.update_state(loss_value)  # Add current batch loss

            # End epoch
            train_loss_results.append(epoch_loss_avg.result())
            total_train_time += float(time.time() - epoch_start)
            # Summarize and save checkpoint at the end of each epoch
            if is_main_process():

                dllogger.log(step=tuple(), data={"e2e_train_time": total_train_time,
                                                 "training_sequences_per_second": float(
                                                     epoch_perf_avg.result().numpy() * get_world_size()),
                                                 "final_loss": float(epoch_loss_avg.result().numpy())})

            if not args.skip_checkpoint:
                if args.ci:
                    checkpoint_name = "{}/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.output_dir, args.version_2_with_negative, epoch + 1)
                else:
                    checkpoint_name = "checkpoints/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.version_2_with_negative, epoch + 1)
                if is_main_process():
                    model.save_weights(checkpoint_name)


        if args.do_predict and (args.evaluate_during_training or epoch == args.num_train_epochs - 1):
            if not args.do_train:
                log("***** Loading checkpoint: {} *****".format(args.init_checkpoint))
                model.load_weights(args.init_checkpoint).expect_partial()

            current_feature_id = 0
            all_results = []
            if is_main_process():
                log("***** Running evaluation *****")
                log("  Num Batches = ", total_dev_steps)
                log("  Batch size = ", args.predict_batch_size)

            raw_infer_start = time.time()
            if is_main_process():
                infer_perf_avg = tf.keras.metrics.Mean()
                dev_iterator = tqdm(dev_dataset, total=total_dev_steps, desc="Iteration", mininterval=5,
                                    disable=not is_main_process())
                for input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask, is_impossible in dev_iterator:
                    # training=False is needed only if there are layers with different
                    # behavior during training versus inference (e.g. Dropout).

                    iter_start = time.time()

                    if not args.joint_head:
                        batch_start_logits, batch_end_logits = infer_step(model, input_ids,
                                                                          attention_mask=input_mask,
                                                                          token_type_ids=segment_ids,
                                                                          )[:2]
                        # Synchronize with GPU to compute time
                        _ = batch_start_logits.numpy()

                    else:
                        outputs = infer_step(model, input_ids,
                                             attention_mask=input_mask,
                                             token_type_ids=segment_ids,
                                             cls_index=cls_index,
                                             p_mask=p_mask,
                                             )
                        # Synchronize with GPU to compute time
                        _ = outputs[0].numpy()

                    infer_time = (time.time() - iter_start)
                    infer_perf_avg.update_state(1. * EVAL_BATCH_SIZE / infer_time)
                    latency.append(infer_time)

                    for iter_ in range(input_ids.shape[0]):

                        if not args.joint_head:
                            start_logits = batch_start_logits[iter_].numpy().tolist()
                            end_logits = batch_end_logits[iter_].numpy().tolist()
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            all_results.append(RawResult(unique_id=unique_id,
                                                         start_logits=start_logits,
                                                         end_logits=end_logits))
                        else:
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            output = [output[iter_].numpy().tolist() for output in outputs]

                            start_logits = output[0]
                            start_top_index = output[1]
                            end_logits = output[2]
                            end_top_index = output[3]
                            cls_logits = output[4]
                            result = SquadResult(
                                unique_id,
                                start_logits,
                                end_logits,
                                start_top_index=start_top_index,
                                end_top_index=end_top_index,
                                cls_logits=cls_logits,
                            )

                            all_results.append(result)

                # Compute and save predictions
                answers, nbest_answers = get_answers(dev_examples, dev_features, all_results, args)

                output_prediction_file = os.path.join(args.output_dir, "predictions.json")
                output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
                e2e_infer_time = time.time() - raw_infer_start
                # if args.version_2_with_negative:
                #     output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
                # else:
                #     output_null_log_odds_file = None
                with open(output_prediction_file, "w") as f:
                    f.write(json.dumps(answers, indent=4) + "\n")
                with open(output_nbest_file, "w") as f:
                    f.write(json.dumps(nbest_answers, indent=4) + "\n")

                if args.do_eval:
                    if args.version_2_with_negative:
                        dev_file = "dev-v2.0.json"
                    else:
                        dev_file = "dev-v1.1.json"

                    eval_out = subprocess.check_output([sys.executable, args.eval_script,
                                                        args.data_dir + "/" + dev_file, output_prediction_file])
                    log(eval_out.decode('UTF-8'))
                    scores = str(eval_out).strip()
                    exact_match = float(scores.split(":")[1].split(",")[0])
                    if args.version_2_with_negative:
                        f1 = float(scores.split(":")[2].split(",")[0])
                    else:
                        f1 = float(scores.split(":")[2].split("}")[0])

                    log("Epoch: {:03d} Results: {}".format(epoch, eval_out.decode('UTF-8')))
                    log("**EVAL SUMMARY** - Epoch: {:03d},  EM: {:6.3f}, F1: {:6.3f}, Infer_Perf: {:4.0f} seq/s"
                          .format(epoch, exact_match, f1, infer_perf_avg.result()))

                latency_all = sorted(latency)[:-2]
                log(
                    "**LATENCY SUMMARY** - Epoch: {:03d},  Ave: {:6.3f} ms, 90%: {:6.3f} ms, 95%: {:6.3f} ms, 99%: {:6.3f} ms"
                    .format(epoch, sum(latency_all) / len(latency_all) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.9)]) / int(len(latency_all) * 0.9) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.95)]) / int(len(latency_all) * 0.95) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.99)]) / int(len(latency_all) * 0.99) * 1000,
                            ))
                dllogger.log(step=tuple(),
                             data={"inference_sequences_per_second": float(infer_perf_avg.result().numpy()), 
                                   "e2e_inference_time": e2e_infer_time})

    if is_main_process() and args.do_train and args.do_eval:
        log(
            "**RESULTS SUMMARY** - EM: {:6.3f}, F1: {:6.3f}, Train_Time: {:4.0f} s, Train_Perf: {:4.0f} seq/s, Infer_Perf: {:4.0f} seq/s"
            .format(exact_match, f1, total_train_time, epoch_perf_avg.result() * get_world_size(),
                    infer_perf_avg.result()))
        dllogger.log(step=tuple(), data={"exact_match": exact_match, "F1": f1})
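
The latency summary above reports the mean of the fastest 90/95/99% of iteration times after dropping the two slowest batches, rather than true percentiles. If order-statistic percentiles are preferred, a small standalone sketch (not part of the original script) over the same `latency` list of per-batch seconds:

import numpy as np

def latency_percentiles(latency_s, drop_slowest=2):
    """Mean and p90/p95/p99 of per-batch inference times, in milliseconds."""
    lat = np.sort(np.asarray(latency_s, dtype=np.float64)) * 1000.0  # seconds -> ms
    if drop_slowest:
        lat = lat[:-drop_slowest]  # mirrors sorted(latency)[:-2] above
    return {
        'mean_ms': float(lat.mean()),
        'p90_ms': float(np.percentile(lat, 90)),
        'p95_ms': float(np.percentile(lat, 95)),
        'p99_ms': float(np.percentile(lat, 99)),
    }
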
Example #4
def main(args):
    nproc_per_node = torch.cuda.device_count()
    if args.affinity != 'disabled':
        affinity = gpu_affinity.set_affinity(
                args.local_rank,
                nproc_per_node,
                args.affinity
            )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Enable CuDNN autotuner
    torch.backends.cudnn.benchmark = True

    ### INIT DISTRIBUTED
    if args.distributed_world_size > 1:
        args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
        args.distributed_rank = dist.get_rank()
        print_once(f'Distributed training with {args.distributed_world_size} GPUs')
        torch.cuda.synchronize()

    if args.seed:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    setup_logger(args)

    config = CONFIGS[args.dataset]()
    if args.overwrite_config:
        config.__dict__.update(json.loads(args.overwrite_config))

    dllogger.log(step='HPARAMS', data={**vars(args), **vars(config)}, verbosity=1)

    model = TemporalFusionTransformer(config).cuda()
    if args.ema_decay:
        model_ema = ModelEma(model, decay=args.ema_decay)

    print_once('Model params: {}'.format(sum(p.numel() for p in model.parameters())))
    criterion = QuantileLoss(config).cuda()
    optimizer = FusedAdam(model.parameters(), lr=args.lr)
    if args.use_amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
    if args.distributed_world_size > 1:
        #model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
        model = DDP(model)

    train_loader, valid_loader, test_loader = load_dataset(args, config)

    global_step = 0
    perf_meter = PerformanceMeter()

    for epoch in range(args.epochs):
        start = time.time()
        dllogger.log(step=global_step, data={'epoch': epoch}, verbosity=1)

        model.train() 
        for local_step, batch in enumerate(train_loader):
            perf_meter.reset_current_lap()
            batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
            predictions = model(batch)
            targets = batch['target'][:,config.encoder_length:,:]
            p_losses = criterion(predictions, targets)
            loss = p_losses.sum()

            if args.use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if not args.grad_accumulation or (global_step+1) % args.grad_accumulation == 0:
                if args.clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                if args.ema_decay:
                    model_ema.update(model)

            if args.distributed_world_size > 1:
                dist.all_reduce(p_losses)
                p_losses /= args.distributed_world_size
                loss = p_losses.sum()

            torch.cuda.synchronize()
            ips = perf_meter.update(args.batch_size * args.distributed_world_size,
                    exclude_from_total=local_step in [0, len(train_loader)-1])

            log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': loss.item(), 'items/s':ips}
            dllogger.log(step=global_step, data=log_dict, verbosity=1)
            global_step += 1

        validate(args, config, model_ema if args.ema_decay else model, criterion, valid_loader, global_step)

        if validate.early_stop_c >= args.early_stopping:
            print_once('Early stopping')
            break

    ### TEST PHASE ###
    state_dict = torch.load(os.path.join(args.results, 'checkpoint.pt'), map_location='cpu')
    if isinstance(model, DDP):
        model.module.load_state_dict(state_dict['model'])
    else:
        model.load_state_dict(state_dict['model'])
    model.cuda().eval()

    tgt_scalers = pickle.load(open(os.path.join(args.data_path, 'tgt_scalers.bin'), 'rb'))
    cat_encodings = pickle.load(open(os.path.join(args.data_path,'cat_encodings.bin'), 'rb'))

    unscaled_predictions, unscaled_targets, _, _ = predict(args, config, model, test_loader, tgt_scalers, cat_encodings)
    losses = QuantileLoss(config)(unscaled_predictions, unscaled_targets)
    normalizer = unscaled_targets.abs().mean()
    quantiles = 2 * losses / normalizer

    if args.distributed_world_size > 1:
        quantiles = quantiles.cuda()
        dist.all_reduce(quantiles)
        quantiles /= args.distributed_world_size

    quantiles = {'test_p10': quantiles[0].item(), 'test_p50': quantiles[1].item(), 'test_p90': quantiles[2].item(), 'sum':sum(quantiles).item()}
    finish_log = {**quantiles, 'average_ips':perf_meter.avg, 'convergence_step':validate.conv_step}
    dllogger.log(step=(), data=finish_log, verbosity=1)
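
Several of these examples average a loss tensor across ranks before logging (`dist.all_reduce(p_losses)` followed by division by the world size). A minimal sketch of that pattern as a reusable helper, assuming the default process group has been initialized as in the example:

import torch.distributed as dist

def all_reduce_mean(tensor):
    """Average a tensor in place across all ranks of the default process group.

    Mirrors the `dist.all_reduce(p_losses); p_losses /= world_size` pattern above;
    a no-op for single-process runs where torch.distributed is not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor /= dist.get_world_size()
    return tensor
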
Example #5
def main(args):
    ## Distributed computing

    # utility for synchronization
    def reduce_tensor(tensor, reduce_sum=False):
        rt = tensor.clone()
        torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
        return rt if reduce_sum else (rt / world_size)

    # enable distributed computing
    if args.distributed:
        set_affinity(args.local_rank)
        num_devices = torch.cuda.device_count()
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        node_rank = args.node_rank
        global_rank = node_rank * num_devices + args.local_rank
        world_size = torch.distributed.get_world_size(
        )  #os.environ['WORLD_SIZE']
    else:
        global_rank, num_devices, world_size = 0, 1, 1

    ## Data format: batch(0) x steps(1) x height(2) x width(3) x channels(4)

    # batch_size (0)
    total_batch_size = args.batch_size
    assert total_batch_size % world_size == 0, \
        'The batch_size is not divisible by world_size.'
    batch_size = total_batch_size // world_size

    # steps (1)
    total_frames = args.future_frames + args.input_frames

    # frame format (2, 3)
    img_resize = (args.img_height != args.img_height_u) or (args.img_width !=
                                                            args.img_width_u)

    ## Model preparation (Conv-LSTM or Conv-TT-LSTM)

    # size of the neural network model (depth and width)
    layers_per_block = (3, 3, 3, 3)
    hidden_channels = (32, 48, 48, 32)
    skip_stride = 2

    # construct the model with the specified hyper-parameters
    model = ConvLSTMNet(
        # architecture of the model
        layers_per_block=layers_per_block,
        hidden_channels=hidden_channels,
        input_channels=1,
        skip_stride=skip_stride,
        cell_params={
            "steps": 3,
            "order": 3,
            "ranks": 8
        },
        # parameters of convolutional operation
        kernel_size=5,
        bias=True).cuda()

    if args.distributed:
        model = DDP(model, device_ids=[args.local_rank])

    PSmodel = PSmodels.PerceptualLoss(model='net-lin',
                                      net='alex',
                                      use_gpu=True,
                                      gpu_ids=[args.local_rank])

    ## Dataset Preparation (MNIST, KTH)
    assert args.dataset in ["MNIST", "KTH"], \
        "The dataset is not currently supported."

    Dataset = {"KTH": KTH_Dataset, "MNIST": MNIST_Dataset}[args.dataset]

    # path to the dataset folder
    DATA_DIR = args.data_path

    assert os.path.exists(DATA_DIR), \
        "The dataset folder does not exist: " + DATA_DIR

    test_dataset = Dataset({
        "path": DATA_DIR,
        "unique_mode": True,
        "num_frames": total_frames,
        "num_samples": args.test_samples,
        "height": args.img_height,
        "width": args.img_width,
        "channels": 1,
        'training': False
    })

    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              drop_last=True,
                                              num_workers=num_devices * 4,
                                              pin_memory=True,
                                              sampler=test_sampler)

    test_samples = len(test_loader) * total_batch_size

    MODEL_FILE = args.model_path

    assert os.path.exists(MODEL_FILE), \
        "The specified model is not found in the folder."

    checkpoint = torch.load(MODEL_FILE)
    eval_epoch = checkpoint.get("epoch", 0)
    model.load_state_dict(checkpoint["model_state_dict"])

    ## Main script for test phase
    MSE_ = torch.zeros((args.future_frames), dtype=torch.float32).cuda()
    PSNR_ = torch.zeros((args.future_frames), dtype=torch.float32).cuda()
    SSIM_ = torch.zeros((args.future_frames), dtype=torch.float32).cuda()
    PIPS_ = torch.zeros((args.future_frames), dtype=torch.float32).cuda()

    with torch.no_grad():
        model.eval()

        samples = 0
        for it, frames in enumerate(test_loader):
            samples += total_batch_size

            frames = torch.mean(frames, dim=-1, keepdim=True)

            if img_resize:
                frames_ = frames.cpu().numpy()
                frames = np.zeros((batch_size, total_frames, args.img_height_u,
                                   args.img_width_u, 1),
                                  dtype=np.float32)

                for b in range(batch_size):
                    for t in range(total_frames):
                        frames[b, t] = skimage.transform.resize(
                            frames_[b, t],
                            (args.img_height_u, args.img_width_u))

                frames = torch.from_numpy(frames)

            # 5-th order: batch_size x total_frames x channels x height x width
            frames = frames.permute(0, 1, 4, 2, 3).cuda()
            inputs = frames[:, :args.input_frames]
            origin = frames[:, -args.future_frames:]

            pred = model(inputs,
                         input_frames=args.input_frames,
                         future_frames=args.future_frames,
                         output_frames=args.future_frames,
                         teacher_forcing=False)

            # clamp the output to [0, 1]
            pred = torch.clamp(pred, min=0, max=1)

            # accumulate the statistics per frame
            for t in range(-args.future_frames, 0):
                origin_, pred_ = origin[:, t], pred[:, t]

                origin_ = origin_.repeat([1, 3, 1, 1])
                pred_ = pred_.repeat([1, 3, 1, 1])

                dist = PSmodel(origin_, pred_)
                PIPS_[t] += torch.sum(dist).item()

            origin = origin.permute(0, 1, 3, 4, 2).cpu().numpy()
            pred = pred.permute(0, 1, 3, 4, 2).cpu().numpy()

            for t in range(-args.future_frames, 0):
                for i in range(batch_size):
                    origin_, pred_ = origin[i, t], pred[i, t]

                    origin_ = np.squeeze(origin_, axis=-1)
                    pred_ = np.squeeze(pred_, axis=-1)

                    MSE_[t] += skimage.metrics.mean_squared_error(
                        origin_, pred_)
                    PSNR_[t] += skimage.metrics.peak_signal_noise_ratio(
                        origin_, pred_)
                    SSIM_[t] += skimage.metrics.structural_similarity(
                        origin_, pred_)

            if args.distributed:
                MSE = reduce_tensor(MSE_, reduce_sum=True) / samples
                PSNR = reduce_tensor(PSNR_, reduce_sum=True) / samples
                SSIM = reduce_tensor(SSIM_, reduce_sum=True) / samples
                PIPS = reduce_tensor(PIPS_, reduce_sum=True) / samples
            else:
                MSE = MSE_ / samples
                PSNR = PSNR_ / samples
                SSIM = SSIM_ / samples
                PIPS = PIPS_ / samples

            if ((it + 1) % 50 == 0
                    or it + 1 == len(test_loader)) and args.local_rank == 0:
                print((it + 1) * total_batch_size, '/', test_samples,
                      ": MSE:  ",
                      torch.mean(MSE).cpu().item() * 1e3, "; PSNR: ",
                      torch.mean(PSNR).cpu().item(), "; SSIM: ",
                      torch.mean(SSIM).cpu().item(), ";LPIPS: ",
                      torch.mean(PIPS).cpu().item())

        if args.distributed:
            MSE = reduce_tensor(MSE_, reduce_sum=True) / test_samples
            PSNR = reduce_tensor(PSNR_, reduce_sum=True) / test_samples
            SSIM = reduce_tensor(SSIM_, reduce_sum=True) / test_samples
            PIPS = reduce_tensor(PIPS_, reduce_sum=True) / test_samples
        else:
            MSE = MSE_ / test_samples
            PSNR = PSNR_ / test_samples
            SSIM = SSIM_ / test_samples
            PIPS = PIPS_ / test_samples

        MSE_AVG = torch.mean(MSE).cpu().item()
        PSNR_AVG = torch.mean(PSNR).cpu().item()
        SSIM_AVG = torch.mean(SSIM).cpu().item()
        PIPS_AVG = torch.mean(PIPS).cpu().item()

        if args.local_rank == 0:
            print(
                "Epoch \t{} \tMSE: \t{} (x1e-3) \tPSNR: \t{} \tSSIM: \t{} \tLPIPS: \t{}"
                .format(eval_epoch, 1e3 * MSE_AVG, PSNR_AVG, SSIM_AVG,
                        PIPS_AVG))
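
The resize branch above round-trips every frame through NumPy and `skimage.transform.resize` one sample and one time step at a time. As an illustrative alternative (not what the original script does), the whole batch can be resized in a single call with `torch.nn.functional.interpolate`, assuming the batch x steps x height x width x channels layout used above:

import torch
import torch.nn.functional as F

def resize_frames(frames, height, width):
    """Bilinearly resize a (B, T, H, W, C) float tensor to (B, T, height, width, C)."""
    b, t, h, w, c = frames.shape
    x = frames.permute(0, 1, 4, 2, 3).reshape(b * t, c, h, w)  # -> (B*T, C, H, W)
    x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
    return x.reshape(b, t, c, height, width).permute(0, 1, 3, 4, 2)
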
Example #6
def main(e2e_start_time):
    # Parse essential arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--model_size",
                        default="base",
                        type=str,
                        help="base or large")
    parser.add_argument("--pretrain_tfrecords", type=str)
    parser.add_argument("--phase2", action='store_true')
    parser.add_argument("--fp16_compression", action='store_true')
    parser.add_argument("--amp",
                        action='store_true',
                        help="Whether to use fp16.")
    parser.add_argument("--xla",
                        action='store_true',
                        help="Whether to use xla.")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)

    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)

    parser.add_argument("--log_freq",
                        type=int,
                        default=10,
                        help="Training metrics logging frequency")
    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", default=None, type=str)
    parser.add_argument("--load_weights", action='store_true')
    parser.add_argument("--weights_dir")

    parser.add_argument("--optimizer",
                        default="adam",
                        type=str,
                        help="adam or lamb")
    parser.add_argument(
        "--skip_adaptive",
        action='store_true',
        help="Whether to apply adaptive LR on LayerNorm and biases")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Number of Gradient Accumulation steps")
    parser.add_argument("--lr_decay_power",
                        type=float,
                        default=0.5,
                        help="LR decay power")
    parser.add_argument("--opt_beta_1",
                        type=float,
                        default=0.878,
                        help="Optimizer beta1")
    parser.add_argument("--opt_beta_2",
                        type=float,
                        default=0.974,
                        help="Optimizer beta2")
    parser.add_argument("--end_lr", type=float, default=0.0, help="Ending LR")
    parser.add_argument("--log_dir",
                        type=str,
                        default=None,
                        help="Path to store logs")
    parser.add_argument("--results_dir",
                        type=str,
                        default=None,
                        help="Path to store all model results")
    parser.add_argument("--skip_checkpoint",
                        action='store_true',
                        default=False,
                        help="Path to store logs")
    parser.add_argument(
        '--json-summary',
        type=str,
        default=None,
        help=
        'If provided, the json summary will be written to the specified file.')
    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Set up tensorflow
    hvd.init()

    args.log_dir = config.log_dir
    # DLLogger
    setup_logger(args)

    set_affinity(hvd.local_rank())
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')
    tf.config.optimizer.set_jit(config.xla)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": config.amp})

    if config.amp:
        policy = tf.keras.mixed_precision.experimental.Policy(
            "mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' %
              policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' %
              policy.variable_dtype)  # Variable dtype: float32

    #tf.random.set_seed(config.seed)

    # Set up config (continued)
    if config.load_weights and config.restore_checkpoint:
        raise ValueError(
            "`load_weights` and `restore_checkpoint` should not be on at the same time."
        )
    if config.phase2 and not config.restore_checkpoint:
        raise ValueError(
            "`phase2` cannot be used without `restore_checkpoint`.")
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir,
                                        'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))

    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    metrics = dict()
    metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")
    metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
        name="masked_lm_accuracy")
    metrics["masked_lm_loss"] = tf.keras.metrics.Mean(name="masked_lm_loss")
    if config.electra_objective:
        metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
            name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(
                name="disc_accuracy")
            metrics["disc_precision"] = tf.keras.metrics.Accuracy(
                name="disc_precision")
            metrics["disc_recall"] = tf.keras.metrics.Accuracy(
                name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(
        config.log_dir, current_time,
        'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(config,
                                         config.train_batch_size,
                                         world_size=get_world_size(),
                                         rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(init_lr=config.learning_rate,
                                 num_train_steps=config.num_train_steps,
                                 num_warmup_steps=config.num_warmup_steps,
                                 weight_decay_rate=config.weight_decay_rate,
                                 optimizer=config.optimizer,
                                 skip_adaptive=config.skip_adaptive,
                                 power=config.lr_decay_power,
                                 beta_1=config.opt_beta_1,
                                 beta_2=config.opt_beta_2,
                                 end_lr=config.end_lr)

    accumulator = GradientAccumulator()
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, "dynamic")

    # Set up model checkpoint
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     phase2=tf.Variable(False),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(
        checkpoint,
        config.checkpoints_dir,
        max_to_keep=config.keep_checkpoint_max)
    if config.restore_checkpoint and config.restore_checkpoint != "latest":
        checkpoint.restore(config.restore_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            config.restore_checkpoint))
    elif config.restore_checkpoint and config.restore_checkpoint == "latest" and manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            manager.latest_checkpoint))
    elif config.load_weights:
        model.generator(model.generator.dummy_inputs)
        model.discriminator(model.discriminator.dummy_inputs)
        model.generator.load_weights(
            os.path.join(config.weights_dir, 'generator', 'tf_model.h5'))
        model.discriminator.load_weights(
            os.path.join(config.weights_dir, 'discriminator', 'tf_model.h5'))
    else:
        log(" ** Initializing from scratch.")

    restore_iterator = bool(
        config.restore_checkpoint) and config.restore_checkpoint == "latest"
    # Initialize global step for phase2
    if config.phase2 and not bool(checkpoint.phase2):
        optimizer.iterations.assign(0)
        checkpoint.step.assign(0)
        checkpoint.phase2.assign(True)
        restore_iterator = False
    if bool(checkpoint.phase2):
        manager = tf.train.CheckpointManager(
            checkpoint,
            config.checkpoints_dir,
            checkpoint_name='ckpt-p2',
            max_to_keep=config.keep_checkpoint_max)

    # Set up iterator checkpoint
    iter_checkpoint = tf.train.Checkpoint(train_iterator=train_iterator,
                                          world_size=tf.Variable(
                                              get_world_size()),
                                          rank=tf.Variable(get_rank()))
    iter_manager = tf.train.CheckpointManager(
        iter_checkpoint,
        os.path.join(config.checkpoints_dir,
                     'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
        checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
        max_to_keep=config.keep_checkpoint_max)
    if restore_iterator and iter_manager.latest_checkpoint:
        ckpt_world_size = tf.train.load_variable(
            iter_manager.latest_checkpoint,
            'world_size/.ATTRIBUTES/VARIABLE_VALUE')
        if ckpt_world_size == get_world_size():
            iter_checkpoint.restore(iter_manager.latest_checkpoint)
            log(" ** Restored iterator checkpoint from {}".format(
                iter_manager.latest_checkpoint),
                all_rank=True)

    utils.heading("Running training")
    accumulator.reset()
    train_start, start_step = time.time(), int(checkpoint.step) - 1
    local_step = 0
    saved_ckpt = False
    while int(checkpoint.step) <= config.num_train_steps:
        saved_ckpt = False
        step = int(checkpoint.step)
        features = next(train_iterator)
        iter_start = time.time()

        # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
        total_loss, eval_fn_inputs = train_one_step(
            config,
            model,
            optimizer,
            features,
            accumulator,
            local_step == 1,
            take_step=local_step % args.gradient_accumulation_steps == 0)
        # if step == 300: tf.profiler.experimental.stop()

        metrics["train_perf"].update_state(config.train_batch_size *
                                           get_world_size() /
                                           (time.time() - iter_start))
        metrics["total_loss"].update_state(values=total_loss)
        metric_fn(config, metrics, eval_fn_inputs)

        if (step % args.log_freq
                == 0) and (local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(v.result().numpy() *
                         100) if "accuracy" in k else float(v.result().numpy())
                for k, v in metrics.items()
            }
            dllogger.log(step=(step, ), data=log_info_dict, verbosity=0)
            log('Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f}, Loss Scaler: {loss_scale}, Elapsed: {elapsed}, ETA: {eta}, '
                .format(step=step,
                        **log_info_dict,
                        loss_scale=optimizer.loss_scale if config.amp else 1,
                        elapsed=utils.get_readable_time(time.time() -
                                                        train_start),
                        eta=utils.get_readable_time(
                            (time.time() - train_start) / (step - start_step) *
                            (config.num_train_steps - step))),
                all_rank=True)

            with train_summary_writer.as_default():
                for key, m in metrics.items():
                    tf.summary.scalar(key, m.result(), step=step)

            if int(checkpoint.step) < config.num_train_steps:
                for m in metrics.values():
                    m.reset_states()

        # Print allreduced metrics on the last step
        if int(checkpoint.step) == config.num_train_steps and (
                local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(hvd.allreduce(v.result()).numpy() * 100) if "accuracy"
                in k else float(hvd.allreduce(v.result()).numpy())
                for k, v in metrics.items()
            }
            log_info_dict["training_sequences_per_second"] = log_info_dict[
                "train_perf"]
            log_info_dict["final_loss"] = log_info_dict["total_loss"]
            log_info_dict["e2e_train_time"] = time.time() - e2e_start_time
            dllogger.log(step=(), data=log_info_dict, verbosity=0)
            log('<FINAL STEP METRICS> Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f},'.
                format(step=step, **log_info_dict),
                all_rank=False)

        if local_step % args.gradient_accumulation_steps == 0:
            checkpoint.step.assign(int(optimizer.iterations))

        local_step += 1
        if not config.skip_checkpoint and (
                local_step %
            (config.save_checkpoints_steps * args.gradient_accumulation_steps)
                == 0):
            saved_ckpt = True
            if is_main_process():
                save_path = manager.save(checkpoint_number=step)
                log(" ** Saved model checkpoint for step {}: {}".format(
                    step, save_path))
            iter_save_path = iter_manager.save(checkpoint_number=step)
            log(" ** Saved iterator checkpoint for step {}: {}".format(
                step, iter_save_path),
                all_rank=True)

    step = (int(checkpoint.step) - 1)
    dllogger.flush()
    if not config.skip_checkpoint and not saved_ckpt:
        if is_main_process():
            save_path = manager.save(checkpoint_number=step)
            log(" ** Saved model checkpoint for step {}: {}".format(
                step, save_path))
        iter_save_path = iter_manager.save(checkpoint_number=step)
        log(" ** Saved iterator checkpoint for step {}: {}".format(
            step, iter_save_path),
            all_rank=True)

    return args
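
This `main` takes the end-to-end start time as an argument and returns the parsed args, so it is presumably driven by a small wrapper. A minimal launcher sketch, assuming only the standard library:

import time

if __name__ == "__main__":
    # Record the wall-clock start so main() can report e2e_train_time on the final step.
    e2e_start = time.time()
    args = main(e2e_start)
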