Example #1
def main():
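    """Load the config, seed the RNGs, create the working directories, build the model, and run training."""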
    args = get_args()
    cfg = Config.fromfile(args.config)
    cfg.fold = args.fold
    global device
    cfg.device = device
    log.info(cfg)

    # torch.cuda.set_device(cfg.gpu)
    util.set_seed(cfg.seed)
    log.info(f'setting seed = {cfg.seed}')

    # setup -------------------------------------
    for f in ['checkpoint', 'train', 'valid', 'test', 'backup']:
        os.makedirs(cfg.workdir + '/' + f, exist_ok=True)
    if False:  # disabled: the project backup below does not work reliably
        file.backup_project_as_zip(
            PROJECT_PATH,
            cfg.workdir + '/backup/code.train.%s.zip' % IDENTIFIER)

    ## model ------------------------------------
    model = model_factory.get_model(cfg)

    # multi-gpu----------------------------------
    if torch.cuda.device_count() > 1 and len(cfg.gpu) > 1:
        log.info(f"Let's use {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    ## train model-------------------------------
    do_train(cfg, model)
Example #2
def bootstrap():
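    """Load the config, seed the RNGs, and configure file logging; return the config."""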
    config = get_config()
    set_seed(seed=config[GENERAL][SEED])
    log_file_name = config[LOG][FILE_PATH]
    print("Writing logs to file name: {}".format(log_file_name))
    logging.basicConfig(filename=log_file_name,
                        format='%(message)s',
                        filemode='w',
                        level=logging.DEBUG)
    return config
Example #3
def set_env(args):
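    """Validate the output directory, optionally attach a remote debugger,
    set up CUDA/distributed training and logging, and seed the RNGs."""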
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
        and not args.resume_train
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(
            address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)
Example #4
def main():
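    """Load the config, seed the RNGs, create the working directories, and generate the submission."""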
    args = get_args()
    cfg = Config.fromfile(args.config)
    cfg.fold = args.fold
    global device
    cfg.device = device
    log.info(cfg)

    # torch.cuda.set_device(cfg.gpu)
    util.set_seed(cfg.seed)
    log.info(f'setting seed = {cfg.seed}')

    # setup -------------------------------------
    for f in ['checkpoint', 'train', 'valid', 'test', 'backup']:
        os.makedirs(cfg.workdir + '/' + f, exist_ok=True)

    make_submission(cfg)
Example #5
def main(config):
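    """Seed the RNGs, locate the saved weights, set up the device, and run the metric and JSD evaluations."""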
    set_seed(config['seed'])
    weights_path = get_weights_dir(config)

    device = cuda_setup(config['cuda'], config['gpu'])
    print(f'Device variable: {device}')
    if device.type == 'cuda':
        print(f'Current CUDA device: {torch.cuda.current_device()}')

    print('\n')
    all_metrics(config, weights_path, device, None, None)
    print('\n')

    set_seed(config['seed'])
    jsd_epoch, jsd_value = jsd(config, weights_path, device)
    print('\n')

    set_seed(config['seed'])
    all_metrics(config, weights_path, device, jsd_epoch, jsd_value)
Example #6
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    print('t_total:', t_total)
    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if (os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
            and args.load_optimizer_scheduler):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint in the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    tr_acc, logging_acc = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if not args.reset_iter:
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

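            # DDP gradient accumulation: all-reduce gradients only on accumulation
            # boundaries; intermediate micro-batches run under no_sync().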
            if (step + 1) % args.gradient_accumulation_steps == 0:

                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            acc = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                acc = acc.float().mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            tr_acc += acc.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if 'fairseq' not in args.train_model_type:
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                    else:
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, 'model.pt'))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval *
                            args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_raking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    acc_scalar = (tr_acc - logging_acc) / args.logging_steps
                    logs["acc"] = acc_scalar
                    logging_acc = tr_acc

                    if is_first_worker():
                        for key, value in logs.items():
                            print(key, type(value))
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #7
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    print('t_total:', t_total)
    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if (os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
            and args.load_optimizer_scheduler):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint in the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        count = 0
        avg_cls_norm = 0
        loss_avg = 0
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            # if not args.reset_iter:
            #     if steps_trained_in_current_epoch > 0:
            #         steps_trained_in_current_epoch -= 1
            #         continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)
            with torch.no_grad():
                outputs = model(*batch)
                cls_norm = outputs[1]
                loss = outputs[0]

                count += 1
                avg_cls_norm += float(cls_norm.cpu().data)
                loss_avg += float(loss.cpu().data)
                print("SEED-Encoder norm: ", cls_norm)
                #print("loss: ",loss)
                #assert 1==0
                #print("optimus norm: ",cls_norm)

            if count == 1024:
                print('avg_cls_sim: ', float(avg_cls_norm) / count)
                print('avg_loss: ', float(loss_avg) / count)
                return

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #8
def main(config):
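    """Train the adversarial autoencoder: build the dataset, encoder, hypernetwork and
    discriminator, then run the training loop with periodic sample and weight/metric saving."""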
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)
    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes its own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(), hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(discriminator.parameters(),
                                                      **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_E.pth')))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{starting_epoch - 1:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{starting_epoch - 1:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(mean=config['normal_mu'],
                                                                          std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # WGAN-GP gradient penalty: push the critic's gradient norm toward 1
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                loss_discriminator = 0.5 * ((real_logit - b)**2 + (synth_logit - a)**2)

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch, size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                                      reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                                          torch.transpose(X_rec, 1, 2).contiguous(),
                                                          batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                loss_encoder = 0.5 * (synth_logit - c)**2

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                                      feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2], in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_weights_frequency'] == 0:
            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
Example #9
def train():
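    """Seed the RNGs, build the train/val dataloaders, model, optimizer and scheduler,
    then run the train/validate loop with checkpointing and TensorBoard logging."""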
    config_print()
    print("SEED : {}".format(GLOBAL_SEED))
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_ids
    set_seed(GLOBAL_SEED)
    best_prec1 = 0.
    write_log = 'logs/%s' % config.dataset_tag + config.gpu_ids
    write_val_log = 'logs/val%s' % config.dataset_tag + config.gpu_ids
    write = SummaryWriter(log_dir=write_log)
    write_val = SummaryWriter(log_dir=write_val_log)
    data_config = getDatasetConfig(config.dataset_tag)

    # load dataset
    train_dataset = CustomDataset(data_config['train'],
                                  data_config['train_root'],
                                  True)  # txt file, train_root_dir, is_training
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.workers,
                              pin_memory=True,
                              worker_init_fn=_init_fn)
    val_dataset = CustomDataset(data_config['val'], data_config['val_root'],
                                False)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True)  #,worker_init_fn=_init_fn)

    print('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.
          format(dataset_name=config.dataset_tag,
                 train_num=len(train_dataset),
                 val_num=len(val_dataset)))

    # define model

    net = init_model(pretrained=True,
                     model_name=config.model_name,
                     class_num=config.class_num)

    # gpu config
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and config.multi_gpu:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)

    # define optimizer
    assert config.optimizer in ['sgd', 'adam'], 'optim name not found!'
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=config.learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=config.learning_rate,
                                     weight_decay=config.weight_decay)

    # define learning scheduler
    assert config.scheduler in ['plateau', 'step', 'muilt_step',
                                'cosine'], 'scheduler not supported!!!'
    if config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               patience=3,
                                                               factor=0.1)
    elif config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=2,
                                                    gamma=0.9)
    elif config.scheduler == 'muilt_step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=[30, 100],
                                                         gamma=0.1)
    elif config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config.epochs)

    # define loss
    criterion = torch.nn.CrossEntropyLoss()

    if use_gpu:
        criterion = criterion.cuda()

    # train/val parameters dict
    state = {
        'model': net,
        'train_loader': train_loader,
        'val_loader': val_loader,
        'criterion': criterion,
        'config': config,
        'optimizer': optimizer,
        'write': write,
        'write_val': write_val
    }
    # define resume
    start_epoch = 0
    if config.resume:
        ckpt = torch.load(config.resume)
        net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']
        best_prec1 = ckpt['best_prec1']
        optimizer.load_state_dict(ckpt['optimizer'])

    # train and val
    engine = Engine()
    for e in range(start_epoch, config.epochs + 1):
        if config.scheduler in ['step', 'muilt_step']:
            scheduler.step()
        lr_train = get_lr(optimizer)
        print("Start epoch %d ==========,lr=%f" % (e, lr_train))
        train_prec, train_loss = engine.train(state, e)
        prec1, val_loss = engine.validate(state, e)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': e + 1,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, config.checkpoint_path)
        write.add_scalars("Accurancy", {'train': train_prec, 'val': prec1}, e)
        write.add_scalars("Loss", {'train': train_loss, 'val': val_loss}, e)
        if config.scheduler == 'plateau':
            scheduler.step(val_loss)
Example #10
def main(config):
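    """Train the VAE variant: encode the non-visible and visible parts of each point cloud,
    generate target-network weights with the hypernetwork, and optimize reconstruction + KL losses."""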
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `custom` '
                         f'or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))

        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

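            # Closed-form KL divergence between N(mu, exp(logvar)) and the standard normal prior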
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld
            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        X = X.cpu().numpy()
        X_whole = X_whole.cpu().numpy()
        X_rec = X_rec.detach().cpu().numpy()

        if epoch % config['save_frequency'] == 0:
            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples',
                         f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0],
                                          X_whole[k][1],
                                          X_whole[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0],
                                          X[k][1],
                                          X[k][2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_frequency'] == 0:
            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'),
                    np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
Example #11
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    def optimizer_to(optim, device):
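        # Move all optimizer-state tensors (and their gradients) onto the given device.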
        for param in optim.state.values():
            # Not sure there are any global tensors in the state dict
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(
                                device)

    torch.cuda.empty_cache()

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"),
                       map_location='cpu'))
    optimizer_to(optimizer, args.device)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint
        # from the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    is_hypersphere_training = (args.hyper_align_weight > 0
                               or args.hyper_unif_weight > 0)
    if is_hypersphere_training:
        logger.info(
            f"training with hypersphere property regularization, align weight {args.hyper_align_weight}, unif weight {args.hyper_unif_weight}"
        )
    if not args.dual_training:
        args.dual_loss_weight = 0.0

    tr_loss_dict = {}

    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    # dev_ndcg = 0
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)
        if os.path.isfile(os.path.join(
                args.model_name_or_path,
                "scheduler.pt")) and args.load_optimizer_scheduler:
            # Load in optimizer and scheduler states
            scheduler.load_state_dict(
                torch.load(
                    os.path.join(args.model_name_or_path, "scheduler.pt")))

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(
                args.ann_dir, is_grouped=(args.grouping_ann_data > 0))
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)

                time.sleep(30)  # wait until transmission finished

                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                # marcodev_ndcg = ndcg_json['marcodev_ndcg']
                logging.info(f"loading:\n{ndcg_json}")
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info(
                    "Total ann queries: %d",
                    len(ann_training_data) if args.grouping_ann_data < 0 else
                    len(ann_training_data) * args.grouping_ann_data)

                if args.grouping_ann_data > 0:
                    if args.polling_loaded_data_batch_from_group:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_polling(
                                args, query_cache, passage_cache))
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_origin(
                                args, query_cache, passage_cache))
                else:
                    if not args.dual_training:
                        if args.triplet:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTripletTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                        else:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                    else:
                        # return quadruplet
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetQuadrapuletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))

                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data)
                        if args.grouping_ann_data < 0 else
                        len(ann_training_data) * args.grouping_ann_data)

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    for key in ndcg_json:
                        if "marcodev" in key:
                            tb_writer.add_scalar(key, ndcg_json[key],
                                                 ann_checkpoint_no)

                    if 'trec2019_ndcg' in ndcg_json:
                        tb_writer.add_scalar("trec2019_ndcg",
                                             ndcg_json['trec2019_ndcg'],
                                             ann_checkpoint_no)

                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        # original way
        if args.grouping_ann_data <= 0:
            batch = tuple(t.to(args.device) for t in batch)
            if args.triplet:
                inputs = {
                    "query_ids": batch[0].long(),
                    "attention_mask_q": batch[1].long(),
                    "input_ids_a": batch[3].long(),
                    "attention_mask_a": batch[4].long(),
                    "input_ids_b": batch[6].long(),
                    "attention_mask_b": batch[7].long()
                }
                if args.dual_training:
                    inputs["neg_query_ids"] = batch[9].long()
                    inputs["attention_mask_neg_query"] = batch[10].long()
                    inputs["prime_loss_weight"] = args.prime_loss_weight
                    inputs["dual_loss_weight"] = args.dual_loss_weight
            else:
                inputs = {
                    "input_ids_a": batch[0].long(),
                    "attention_mask_a": batch[1].long(),
                    "input_ids_b": batch[3].long(),
                    "attention_mask_b": batch[4].long(),
                    "labels": batch[6]
                }
        else:
            # the default collate_fn will convert item["q_pos"] into batch format ... I guess
            inputs = {
                "query_ids": batch["q_pos"][0].to(args.device).long(),
                "attention_mask_q": batch["q_pos"][1].to(args.device).long(),
                "input_ids_a": batch["d_pos"][0].to(args.device).long(),
                "attention_mask_a": batch["d_pos"][1].to(args.device).long(),
                "input_ids_b": batch["d_neg"][0].to(args.device).long(),
                "attention_mask_b": batch["d_neg"][1].to(args.device).long(),
            }
            if args.dual_training:
                inputs["neg_query_ids"] = batch["q_neg"][0].to(
                    args.device).long()
                inputs["attention_mask_neg_query"] = batch["q_neg"][1].to(
                    args.device).long()
                inputs["prime_loss_weight"] = args.prime_loss_weight
                inputs["dual_loss_weight"] = args.dual_loss_weight

        inputs["temperature"] = args.temperature
        inputs["loss_objective"] = args.loss_objective_function

        if is_hypersphere_training:
            inputs["alignment_weight"] = args.hyper_align_weight
            inputs["uniformity_weight"] = args.hyper_unif_weight

        step += 1

        if args.local_rank != -1:
            # sync gradients only at gradient accumulation step
            if step % args.gradient_accumulation_steps == 0:
                outputs = model(**inputs)
            else:
                with model.no_sync():
                    outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        loss_item_dict = outputs[1]

        if args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k].mean()

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
            for k in loss_item_dict:
                loss_item_dict[
                    k] = loss_item_dict[k] / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if args.local_rank != -1:
                if step % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()
            else:
                loss.backward()

        def incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss):
            for k in loss_item_dict:
                if k not in tr_loss_dict:
                    tr_loss_dict[k] = loss_item_dict[k].item()
                else:
                    tr_loss_dict[k] += loss_item_dict[k].item()
            if "loss_total" not in tr_loss_dict:
                tr_loss_dict["loss_total"] = total_loss.item()
            else:
                tr_loss_dict["loss_total"] += total_loss.item()
            return tr_loss_dict

        tr_loss_dict = incremental_tr_loss(tr_loss_dict,
                                           loss_item_dict,
                                           total_loss=loss)

        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                learning_rate_scalar = scheduler.get_lr()[0]

                logs["learning_rate"] = learning_rate_scalar

                for k in tr_loss_dict:
                    logs[k] = tr_loss_dict[k] / args.logging_steps
                tr_loss_dict = {}

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
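getattr_recursive is used above to select parameter groups but is not defined in this snippet. A minimal sketch of such a helper, walking a dotted attribute path and returning None when any link is missing (the exact behaviour of the original helper is an assumption):

def getattr_recursive(obj, name):
    """Resolve a dotted attribute path such as 'roberta.encoder.layer'."""
    for attr in name.split('.'):
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj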
Exemplo n.º 12
0
def do_train(cfg, model):
    # get criterion -----------------------------
    criterion = criterion_factory.get_criterion(cfg)

    # get optimization --------------------------
    optimizer = optimizer_factory.get_optimizer(model, cfg)

    # initial -----------------------------------
    best = {
        'loss': float('inf'),
        'score': 0.0,
        'epoch': -1,
    }

    # resume model ------------------------------
    if cfg.resume_from:
        log.info('\n')
        log.info(f're-load model from {cfg.resume_from}')
        detail = util.load_model(cfg.resume_from, model, optimizer, cfg.device)
        best.update({
            'loss': detail['loss'],
            'score': detail['score'],
            'epoch': detail['epoch'],
        })

    # scheduler ---------------------------------
    scheduler = scheduler_factory.get_scheduler(cfg, optimizer, best['epoch'])

    # fp16 --------------------------------------
    if cfg.apex:
        amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # setting dataset ---------------------------
    loader_train = dataset_factory.get_dataloader(cfg.data.train)
    loader_valid = dataset_factory.get_dataloader(cfg.data.valid)

    # start training ----------------------------
    start_time = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
    log.info('\n')
    log.info(f'** start train [fold{cfg.fold}th] {start_time} **\n')
    log.info(
        'epoch    iter      rate     | smooth_loss/score | valid_loss/score | best_epoch/best_score |  min'
    )
    log.info(
        '-------------------------------------------------------------------------------------------------'
    )

    for epoch in range(best['epoch'] + 1, cfg.epoch):
        end = time.time()
        util.set_seed(epoch)

        ## train model --------------------------
        train_results = run_nn(cfg.data.train,
                               'train',
                               model,
                               loader_train,
                               criterion=criterion,
                               optimizer=optimizer,
                               apex=cfg.apex,
                               epoch=epoch)

        ## valid model --------------------------
        with torch.no_grad():
            val_results = run_nn(cfg.data.valid,
                                 'valid',
                                 model,
                                 loader_valid,
                                 criterion=criterion,
                                 epoch=epoch)

        detail = {
            'score': val_results['score'],
            'loss': val_results['loss'],
            'epoch': epoch,
        }

        if val_results['loss'] <= best['loss']:
            best.update(detail)
            util.save_model(model, optimizer, detail, cfg.fold[0],
                            os.path.join(cfg.workdir, 'checkpoint'))


        log.info('%5.1f   %5d    %0.6f   |  %0.4f  %0.4f  |  %0.4f  %6.4f |  %6.1f     %6.4f    | %3.1f min' % \
                (epoch+1, len(loader_train), util.get_lr(optimizer), train_results['loss'], train_results['score'], val_results['loss'], val_results['score'], best['epoch'], best['score'], (time.time() - end) / 60))

        scheduler.step(
            val_results['loss'])  # if scheduler is ReduceLROnPlateau
        # scheduler.step()

        # early stopping-------------------------
        if cfg.early_stop:
            if epoch - best['epoch'] > cfg.early_stop:
                log.info('=================================> early stopping!')
                break
        time.sleep(0.01)
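The scheduler above is stepped with the validation loss, which matches the ReduceLROnPlateau interface; scheduler_factory is not shown here, so this is only a minimal stand-alone sketch of that pattern:

import torch

model = torch.nn.Linear(10, 1)                        # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

for epoch in range(10):
    optimizer.step()                                  # stand-in for a real training epoch
    val_loss = 1.0 / (epoch + 1)                      # placeholder validation loss
    scheduler.step(val_loss)                          # reduces the LR when val_loss plateaus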
Exemplo n.º 13
0
def train(args):

    set_seed(args.seed)

    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.gpu else 'cpu')

    batch_size = args.batch_size
    max_length = args.max_length
    mtl = args.mtl
    learning_rate = args.learning_rate

    # Defining CrossEntropyLoss as default
    criterion = nn.CrossEntropyLoss(ignore_index=constants.PAD_IDX)
    clipping = args.gradient_clipping

    #train_source_files = ["data/ordering/train.src", "data/structing/train.src", "data/lexicalization/train.src"]
    #train_target_files = ["data/ordering/train.trg", "data/structing/train.trg", "data/lexicalization/train.trg"]
    #dev_source_files = ["data/ordering/dev.src", "data/structing/dev.src", "data/lexicalization/dev.src"]
    #dev_target_files = ["data/ordering/dev.trg", "data/structing/dev.trg", "data/lexicalization/dev.trg"]

    if len(args.train_source) != len(args.train_target):
        print("Error.Number of inputs in train are not the same")
        return

    if len(args.dev_source) != len(args.dev_target):
        print("Error: Number of inputs in dev are not the same")
        return

    print("Building Encoder vocabulary")
    source_vocabs = build_vocab(args.train_source,
                                args.src_vocab,
                                save_dir=args.save_dir)
    print("Building Decoder vocabulary")
    target_vocabs = build_vocab(args.train_target,
                                args.tgt_vocab,
                                mtl=mtl,
                                name="tgt",
                                save_dir=args.save_dir)

    # source_vocabs, target_vocabs = build_vocab(args.train_source, args.train_target, mtl=mtl)

    print("Building training set and dataloaders")
    train_loaders = build_dataset(args.train_source, args.train_target, batch_size, \
      source_vocabs=source_vocabs, target_vocabs=target_vocabs, shuffle=True, mtl=mtl, max_length=max_length)
    for train_loader in train_loaders:
        print(
            f'Train - {len(train_loader):d} batches with size: {batch_size:d}')

    print("Building dev set and dataloaders")
    dev_loaders = build_dataset(args.dev_source, args.dev_target, batch_size, \
      source_vocabs=source_vocabs, target_vocabs=target_vocabs, mtl=mtl, max_length=max_length)
    for dev_loader in dev_loaders:
        print(f'Dev - {len(dev_loader):d} batches with size: {batch_size:d}')

    print("Building model")
    model = build_model(args, source_vocabs[0], target_vocabs[0], device,
                        max_length)
    print(
        f'The Transformer has {count_parameters(model):,} trainable parameters'
    )
    print(
        f'The Encoder has {count_parameters(model.encoder):,} trainable parameters'
    )
    print(
        f'The Decoder has {count_parameters(model.decoder):,} trainable parameters'
    )

    # Default optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    task_id = 0
    print_loss_total = 0  # Reset every print_every

    n_tasks = len(train_loaders)
    best_valid_loss = [float('inf') for _ in range(n_tasks)]

    for _iter in range(1, args.steps + 1):

        train_loss = train_step(model,
                                train_loaders[task_id],
                                optimizer,
                                criterion,
                                clipping,
                                device,
                                task_id=task_id)
        print_loss_total += train_loss

        if _iter % args.print_every == 0:
            print_loss_avg = print_loss_total / args.print_every
            print_loss_total = 0
            print(
                f'Task: {task_id:d} | Step: {_iter:d} | Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}'
            )

        if _iter % args.eval_steps == 0:
            print("Evaluating...")
            valid_loss = evaluate(model,
                                  dev_loaders[task_id],
                                  criterion,
                                  device,
                                  task_id=task_id)
            print(
                f'Task: {task_id:d} | Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}'
            )
            if valid_loss < best_valid_loss[task_id]:
                print(
                    f'The loss decreased from {best_valid_loss[task_id]:.3f} to {valid_loss:.3f} in the task {task_id}... saving checkpoint'
                )
                best_valid_loss[task_id] = valid_loss
                torch.save(model.state_dict(), args.save_dir + 'model.pt')
                print("Saved model.pt")

            if n_tasks > 1:
                print("Changing to the next task ...")
                task_id = (0 if task_id == n_tasks - 1 else task_id + 1)

    model.load_state_dict(torch.load(args.save_dir + 'model.pt'))

    print("Evaluating and testing")
    for index, eval_name in enumerate(args.eval):
        n = len(eval_name.split("/"))
        name = eval_name.split("/")[n - 1]
        print(f'Reading {eval_name}')
        fout = open(args.save_dir + name + "." + str(index) + ".out", "w")
        with open(eval_name, "r") as f:
            for sentence in f:
                output = translate_sentence(model, index, sentence,
                                            source_vocabs[0],
                                            target_vocabs[index], device,
                                            max_length)
                fout.write(output.replace("<eos>", "").strip() + "\n")
        fout.close()

    for index, test_name in enumerate(args.test):
        n = len(test_name.split("/"))
        name = test_name.split("/")[n - 1]
        print(f'Reading {test_name}')
        fout = open(args.save_dir + name + "." + str(index) + ".out", "w")
        with open(test_name, "r") as f:
            for sentence in f:
                output = translate_sentence(model, index, sentence,
                                            source_vocabs[0],
                                            target_vocabs[index], device,
                                            max_length)
                fout.write(output.replace("<eos>", "").strip() + "\n")
        fout.close()
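The evaluation and test loops above repeat the same per-file translation logic; a hedged refactoring sketch that folds it into one helper (it reuses translate_sentence and the vocab objects from this example, so it is not standalone):

def translate_file(model, index, in_path, out_path, source_vocab, target_vocab,
                   device, max_length):
    # Translate every line of in_path and write the cleaned output to out_path.
    with open(in_path, "r") as fin, open(out_path, "w") as fout:
        for sentence in fin:
            output = translate_sentence(model, index, sentence, source_vocab,
                                        target_vocab, device, max_length)
            fout.write(output.replace("<eos>", "").strip() + "\n")

# e.g. for index, eval_name in enumerate(args.eval):
#          translate_file(model, index, eval_name,
#                         args.save_dir + os.path.basename(eval_name) + f".{index}.out",
#                         source_vocabs[0], target_vocabs[index], device, max_length)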
Exemplo n.º 14
0
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(
        1, args.n_gpu)  #nll loss for query
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

        #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
        #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
        #if is_first_worker():
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
    print(args.num_epoch)
    #step = global_step
    print(step, args.max_steps, global_step)

    global_step = 0
    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:

            if args.num_epoch == 0:
                #print('yes')
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                #print(ann_path)
                #print(ann_no)
                #print(ndcg_json)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    time.sleep(180)
                    with open(ann_path, 'r') as f:
                        #print(ann_path)
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d",
                                len(ann_training_data))
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1
                    ]
                    logger.info("Filtered training data line count: %d",
                                len(ann_training_data))
                    #ann_checkpoint_path = ndcg_json['checkpoint']
                    #ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset,
                            batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # add ndcg at checkpoint step used instead of current step
                        #tb_writer.add_scalar("retrieval_accuracy/top20_nq", ndcg_json['top20'], ann_checkpoint_no)
                        #tb_writer.add_scalar("retrieval_accuracy/top100_nq", ndcg_json['top100'], ann_checkpoint_no)
                        #if 'top20_trivia' in ndcg_json:
                        #    tb_writer.add_scalar("retrieval_accuracy/top20_trivia", ndcg_json['top20_trivia'], ann_checkpoint_no)
                        #    tb_writer.add_scalar("retrieval_accuracy/top100_trivia", ndcg_json['top100_trivia'], ann_checkpoint_no)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:

                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size * 2)
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
            #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
            #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
            #if is_first_worker():
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                with open(ann_path, 'r') as f:
                    print(ann_path)
                    ann_training_data = f.readlines()
                logger.info("Training data line count: %d",
                            len(ann_training_data))
                ann_training_data = [
                    l for l in ann_training_data
                    if len(l.split('\t')[2].split(',')) > 1
                ]
                logger.info("Filtered training data line count: %d",
                            len(ann_training_data))

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                train_dataset = StreamingDataset(
                    ann_training_data,
                    GetTrainingDataProcessingFn(args, query_cache,
                                                passage_cache))
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size * 2)

            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
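Both trainers in this style combine gradient accumulation, gradient clipping and a linear warmup schedule. A stripped-down sketch of that core loop with placeholder data (get_linear_schedule_with_warmup is the transformers helper used above):

import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 1)                         # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accum_steps = 4
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=5, num_training_steps=25)

for step in range(1, 101):
    loss = model(torch.randn(2, 8)).mean() / accum_steps   # scale by accumulation factor
    loss.backward()
    if step % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()                              # update learning rate schedule
        optimizer.zero_grad()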
Exemplo n.º 15
0
def train(args):

    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')

    batch_size = args.batch_size
    max_length = args.max_length
    mtl = args.mtl

    learning_rate = 0.0005
    if args.learning_rate:
        learning_rate = args.learning_rate

    if len(args.train_source) != len(args.train_target):
        print("Error.Number of inputs in train are not the same")
        return

    if len(args.dev_source) != len(args.dev_target):
        print("Error: Number of inputs in dev are not the same")
        return

    if not args.tie_embeddings:
        print("Building Encoder vocabulary")
        source_vocabs = build_vocab(args.train_source, args.src_vocab, save_dir=args.save_dir)
        print("Building Decoder vocabulary")
        target_vocabs = build_vocab(args.train_target, args.tgt_vocab, mtl=mtl, name ="tgt", save_dir=args.save_dir)
    else:
        print("Building Share vocabulary")
        source_vocabs = build_vocab(args.train_source + args.train_target, args.src_vocab, name="tied", save_dir=args.save_dir)
        if mtl:
            target_vocabs = [source_vocabs[0] for _ in range(len(args.train_target))]
        else:
            target_vocabs = source_vocabs
    print("Number of source vocabularies:", len(source_vocabs))
    print("Number of target vocabularies:", len(target_vocabs))

    save_params(args, args.save_dir + "args.json")

    # source_vocabs, target_vocabs = build_vocab(args.train_source, args.train_target, mtl=mtl)

    print("Building training set and dataloaders")
    train_loaders = build_dataset(args.train_source, args.train_target, batch_size, \
            source_vocabs=source_vocabs, target_vocabs=target_vocabs, shuffle=True, mtl=mtl, max_length=max_length)
    for train_loader in train_loaders:
        print(f'Train - {len(train_loader):d} batches with size: {batch_size:d}')

    print("Building dev set and dataloaders")
    dev_loaders = build_dataset(args.dev_source, args.dev_target, batch_size, \
            source_vocabs=source_vocabs, target_vocabs=target_vocabs, mtl=mtl, max_length=max_length)
    for dev_loader in dev_loaders:
        print(f'Dev - {len(dev_loader):d} batches with size: {batch_size:d}')

    if args.model is not None:
        print("Loading the encoder from an external model...")
        multitask_model = load_model(args, source_vocabs, target_vocabs, device, max_length)
    else:
        print("Building model")
        multitask_model = build_model(args, source_vocabs, target_vocabs, device, max_length)

    print(f'The Transformer has {count_parameters(multitask_model):,} trainable parameters')
    print(f'The Encoder has {count_parameters(multitask_model.encoder):,} trainable parameters')
    for index, decoder in enumerate(multitask_model.decoders):
        print(f'The Decoder {index+1} has {count_parameters(decoder):,} trainable parameters')


    # Defining CrossEntropyLoss as default
    #criterion = nn.CrossEntropyLoss(ignore_index = constants.PAD_IDX)
    criterions = [LabelSmoothing(size=target_vocab.len(), padding_idx=constants.PAD_IDX, smoothing=0.1) \
                                        for target_vocab in target_vocabs]

    # Default optimizer
    optimizer = torch.optim.Adam(multitask_model.parameters(), lr = learning_rate, betas=(0.9, 0.98), eps=1e-09)
    model_opts = [NoamOpt(args.hidden_size, args.warmup_steps, optimizer) for _ in target_vocabs]

    task_id = 0
    print_loss_total = 0  # Reset every print_every

    n_tasks = len(train_loaders)
    best_valid_loss = [float(0) for _ in range(n_tasks)]

    if not args.translate:
        print("Start training...")
        patience = 30
        if args.patience:
            patience = args.patience

        if n_tasks > 1:
            print("Patience wont be taking into account in Multitask learning")

        for _iter in range(1, args.steps + 1):

            train_loss = train_step(multitask_model, train_loaders[task_id], \
                       LossCompute(criterions[task_id], model_opts[task_id]), device, task_id = task_id)

            print_loss_total += train_loss

            if _iter % args.print_every == 0:
                print_loss_avg = print_loss_total / args.print_every
                print_loss_total = 0
                print(f'Task: {task_id:d} | Step: {_iter:d} | Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')


            if _iter % args.eval_steps == 0:
                print("Evaluating...")
                accuracies = run_evaluation(multitask_model, source_vocabs[0], target_vocabs, device, args.beam_size, args.eval, args.eval_ref, max_length)
                accuracy = round(accuracies[task_id], 3)
                valid_loss = evaluate(multitask_model, dev_loaders[task_id], LossCompute(criterions[task_id], None), \
                                device, task_id=task_id)
                print(f'Task: {task_id:d} | Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} | Acc. {accuracy:.3f}')
                if accuracy > best_valid_loss[task_id]:
                    print(f'The accuracy increased from {best_valid_loss[task_id]:.3f} to {accuracy:.3f} in the task {task_id}... saving checkpoint')
                    patience = 30
                    best_valid_loss[task_id] = accuracy
                    torch.save(multitask_model.state_dict(), args.save_dir + 'model.pt')
                    print("Saved model.pt")
                else:
                    if n_tasks == 1:
                        if patience == 0:
                            break
                        else:
                            patience -= 1

                if n_tasks > 1:
                    print("Changing to the next task ...")
                    task_id = (0 if task_id == n_tasks - 1 else task_id + 1)

    try:
        multitask_model.load_state_dict(torch.load(args.save_dir + 'model.pt'))
    except:
        print(f'There is no model in the following path {args.save_dir}')
        return

    print("Evaluating and testing")
    run_translate(multitask_model, source_vocabs[0], target_vocabs, args.save_dir, device, args.beam_size, args.eval, max_length=max_length)
    run_translate(multitask_model, source_vocabs[0], target_vocabs, args.save_dir, device, args.beam_size, args.test, max_length=max_length)
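NoamOpt and LossCompute come from the surrounding project and are not shown here; the learning-rate rule a NoamOpt wrapper typically implements (from "Attention Is All You Need") can be sketched as:

def noam_rate(step, d_model, warmup_steps, factor=1.0):
    # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    step = max(step, 1)
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

# e.g. with d_model=512 and warmup_steps=4000 the rate grows linearly for the
# first 4000 optimizer steps and then decays proportionally to step**-0.5.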
Exemplo n.º 16
0
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint
        # from the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    save_no = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0 and global_step % args.save_steps < args.save_steps / 20:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

                save_no += 1

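                # block until get_latest_ann_data reports a new ANN data batch before continuing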
                if save_no > 1:
                    ann_no, ann_path, ndcg_json = get_latest_ann_data(
                        args.ann_dir)
                    while ann_no == last_ann_no:
                        print("Waiting for new ann data; sleeping for 1 hour.")
                        time.sleep(3600)
                        ann_no, ann_path, ndcg_json = get_latest_ann_data(
                            args.ann_dir)

            dist.barrier()

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step


def main():
    tvt_pairs_dict = load_pair_tvt_splits()

    orig_dataset = load_dataset(FLAGS.dataset, 'all', FLAGS.node_feats,
                                FLAGS.edge_feats)

    orig_dataset, num_node_feat = encode_node_features(dataset=orig_dataset)
    num_interaction_edge_feat = encode_edge_features(
        orig_dataset.interaction_combo_nxgraph, FLAGS.hyper_eatts)

    for i, (train_pairs, val_pairs, test_pairs) in \
            enumerate(zip(tvt_pairs_dict['train'],
                          tvt_pairs_dict['val'],
                          tvt_pairs_dict['test'])):
        fold_num = i + 1
        if FLAGS.cross_val and FLAGS.run_only_on_fold != -1 and FLAGS.run_only_on_fold != fold_num:
            continue

        set_seed(FLAGS.random_seed + 5)
        print(f'======== FOLD {fold_num} ========')
        saver = Saver(fold=fold_num)
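        # per-fold copy so preprocessing in this fold does not mutate the original dataset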
        dataset = deepcopy(orig_dataset)
        train_data, val_data, test_data, val_pairs, test_pairs, _ = \
            load_pairs_to_dataset(num_node_feat, num_interaction_edge_feat,
                                  train_pairs, val_pairs, test_pairs,
                                  dataset)
        print('======== Training... ========')
        if FLAGS.load_model is not None:
            print('loading models: {}'.format(FLAGS.load_model))
            trained_model = Model(train_data).to(FLAGS.device)
            trained_model.load_state_dict(torch.load(
                FLAGS.load_model, map_location=FLAGS.device),
                                          strict=False)
            print('models loaded')
            print(trained_model)
        else:
            train(train_data, val_data, val_pairs, saver, fold_num=fold_num)
            trained_model = saver.load_trained_model(train_data)
            if FLAGS.save_model:
                saver.save_trained_model(trained_model)

        print('======== Testing... ========')

        if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
            _get_initial_embd(test_data, trained_model)
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
        elif FLAGS.higher_level_layers and not FLAGS.lower_level_layers:
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)

        if FLAGS.save_final_node_embeddings and 'gmn' not in FLAGS.model:
            with torch.no_grad():
                trained_model = trained_model.to(FLAGS.device)
                trained_model.eval()
                if FLAGS.higher_level_layers:
                    batch_data = model_forward(trained_model,
                                               test_data,
                                               is_train=False)
                    trained_model.use_layers = "higher_no_eval_layers"
                    outs = trained_model(batch_data)
                else:
                    outs = _get_initial_embd(test_data, trained_model)
                    trained_model.use_layers = 'all'

            saver.save_graph_embeddings_mat(outs.cpu().detach().numpy(),
                                            test_data.dataset.id_map,
                                            test_data.dataset.gs_map)
            if FLAGS.higher_level_layers:
                batch_data.restore_interaction_nxgraph()

        test(trained_model, test_data, test_pairs, saver, fold_num)
        overall_time = convert_long_time_to_str(time() - t)
        print(overall_time)
        print(saver.get_log_dir())
        print(basename(saver.get_log_dir()))
        saver.save_overall_time(overall_time)
        saver.close()
    if FLAGS.cross_val and COMET_EXPERIMENT:
        results = aggregate_comet_results_from_folds(
            COMET_EXPERIMENT, FLAGS.num_folds, FLAGS.dataset,
            FLAGS.eval_performance_by_degree)
        COMET_EXPERIMENT.log_metrics(results, prefix='aggr')


def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config,
                                      config['arch'],
                                      'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        exit(1)

    setup_logging(results_dir)
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=64,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

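    # running sums of the combined, reconstruction and KLD losses over the evaluation set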
    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)
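            # encode the non-visible (pocket) cloud and the visible cloud; the concatenated codes condition the hypernetwork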
            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
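            # the hypernetwork emits one set of target-network weights per sample; each target network decodes its own cloud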
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X_whole.shape[2],
                                                             X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

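            # closed-form KL divergence between N(mu, exp(logvar)) and the standard normal, summed over the batch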
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # truncate every batch to the smallest last-dim (number of points) so they can be concatenated
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(
                x, encoder_pocket, hyper_network, device, results_dir, epoch,
                config['experiments']['interpolation']['amount'],
                config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points']['execute']:
            exp_cfg = config['experiments']['interpolation_between_two_points']
            interpolation_between_two_points(encoder_pocket, hyper_network,
                                             device, x, results_dir, epoch,
                                             exp_cfg['amount'],
                                             exp_cfg['image_points'],
                                             exp_cfg['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch, config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles']['amount'],
                config['experiments']['sphere_triangles']['method'],
                config['experiments']['sphere_triangles']['depth'],
                config['experiments']['sphere_triangles']['start'],
                config['experiments']['sphere_triangles']['end'],
                config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            exp_cfg = config['experiments']['sphere_triangles_interpolation']
            sphere_triangles_interpolation(encoder_pocket, hyper_network,
                                           device, x, results_dir,
                                           exp_cfg['amount'], exp_cfg['method'],
                                           exp_cfg['depth'],
                                           exp_cfg['coefficient'],
                                           exp_cfg['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            exp_cfg = config['experiments']['different_number_of_points']
            different_number_of_points(encoder_pocket, hyper_network, x, device,
                                       results_dir, epoch, exp_cfg['amount'],
                                       exp_cfg['image_points'])

        if config['experiments']['fixed']['execute']:
            # get a visible sample from the loader (ideally this would come from a
            # user-specified object, e.g. passed in via a parser argument)

            points_dataloader = DataLoader(dataset,
                                           batch_size=10,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=True,
                                           pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch, config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'], (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
Exemplo n.º 19
0
def _init_fn(worker_id):
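    # give each DataLoader worker a distinct, reproducible seed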
    set_seed(GLOBAL_SEED + worker_id)