Example 1
import argparse
import logging
import os
import pickle

import torch
import tensorboard_logger as tb_logger

# Project modules from the SCAN codebase (assumed to sit alongside this script);
# train, validate, adjust_learning_rate and save_checkpoint are defined elsewhere
# in the same training script.
import data
from model import SCAN


def main():
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/data',
        help='path to datasets')
    parser.add_argument('--data_name',
                        default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument(
        '--vocab_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/vocab/',
        help='Path to saved vocabulary pickle files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=20,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--decoder_dim',
                        default=512,
                        type=int,
                        help='Dimensionality of the decoder hidden state.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=10,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=30,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument(
        '--resume',
        default=
        '/data3/zhangyf/cross_modal_retrieval/vsepp_next_train_12_31_f30k/run/coco_vse++_ft_128_f30k_next/model_best.pth.tar',
        type=str,
        metavar='PATH',
        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--reset_train',
                        action='store_true',
                        help='Ensure the training is always done in '
                        'train mode (Not recommended).')
    parser.add_argument('--finetune',
                        action='store_true',
                        help='Fine-tune the image encoder.')
    parser.add_argument('--cnn_type',
                        default='resnet152',
                        help="""The CNN used for image encoder
                        (e.g. vgg19, resnet152)""")
    parser.add_argument('--crop_size',
                        default=224,
                        type=int,
                        help='Size of an image crop as the CNN input.')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = pickle.load(
        open(os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name),
             'rb'))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    best_rsum = 0
    for epoch in range(opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        best_rsum = train(opt, train_loader, model, epoch, val_loader,
                          best_rsum)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
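
All three examples call adjust_learning_rate without showing it. In the reference VSE++/SCAN training scripts this helper decays the initial learning rate by a factor of 10 every opt.lr_update epochs; the sketch below assumes that schedule and the --learning_rate/--lr_update options defined above.

def adjust_learning_rate(opt, optimizer, epoch):
    # Decay the initial learning rate by 10x every opt.lr_update epochs,
    # the step schedule used by the reference VSE++/SCAN training scripts.
    lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr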
Example 2
import argparse
import logging
import os

import torch
import tensorboard_logger as tb_logger

# Project modules from the SCAN codebase (assumed; module paths may differ per fork).
import data
from model import SCAN
from vocab import deserialize_vocab


def main():
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=30,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm',
        default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func',
                        default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru',
                        action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse',
                        default=6.,
                        type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    best_rsum = 0
    for epoch in range(opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
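
save_checkpoint is likewise defined elsewhere in the training script. A minimal sketch of what it typically does in this codebase, serialize the full training state and keep a copy of the best model so far, is shown here; the original adds retry/error handling, and the fork in Example 3 passes an extra last_epoch flag, so its signature differs slightly.

import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix=''):
    # Persist the full training state, then copy it aside as the best model so far.
    torch.save(state, prefix + filename)
    if is_best:
        shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')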
Example 3
import os
import random

import numpy as np
import torch

# Project modules from this SCAN fork (assumed; module paths may differ).
# save_hyperparameters, train, validate, adjust_learning_rate and save_checkpoint
# are defined elsewhere in the fork.
import data_ken
from model import SCAN
from vocab import deserialize_vocab


def start_experiment(opt, seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    print("Number threads:", torch.get_num_threads())

    # Load Vocabulary Wrapper, create dictionary that can switch between ids and words
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))

    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data_ken.get_loaders(opt.data_name, vocab,
                                                    opt.batch_size,
                                                    opt.workers, opt)

    # Construct the model
    model = SCAN(opt)

    # save hyperparameters in file
    save_hyperparameters(opt.logger_name, opt)

    best_rsum = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)
        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)

        last_epoch = False
        if epoch == (opt.num_epochs - 1):
            last_epoch = True

        # only save when best epoch, or last epoch for further training
        if is_best or last_epoch:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                last_epoch,
                filename='checkpoint_{}.pth.tar'.format(epoch),
                prefix=opt.model_name + '/')
    return best_rsum
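
Example 3 also calls save_hyperparameters(opt.logger_name, opt) before training, a helper not shown in the snippet. A hypothetical sketch that writes the parsed options to a text file for reproducibility follows; the file name and layout are assumptions, not the fork's actual implementation.

import os


def save_hyperparameters(log_dir, opt):
    # Hypothetical helper: dump every parsed option to a text file so the run
    # can be reproduced later. File name and format are assumptions.
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(os.path.join(log_dir, 'hyperparameters.txt'), 'w') as f:
        for key, value in sorted(vars(opt).items()):
            f.write('{}: {}\n'.format(key, value))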