Example #1
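This fragment assumes a handful of imports and an `args` dictionary of command-line options (e.g. as produced by a function like `read_command_line()` in Example #2 below). A minimal sketch of the assumed preamble; the project-local module paths for `support` and `Assembler` are guesses, not confirmed by the source:

import json
import os

import numpy as np
import tensorflow as tf

# Project-local helpers used below; the exact module paths are assumptions.
# from util import support
# from models_vd.assembler import Assembler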
np.random.seed(1234)
tf.set_random_seed(1234)

# read the train args from checkpoint
param_path = args['checkpoint'].replace('.tmodel', '_params.json')
with open(param_path, 'r') as file_id:
  saved_args = json.load(file_id)

saved_args.update(args)
args = saved_args
args['preload_feats'] = False
# no supervision is needed
args['supervise_attention'] = False
args['batch_size'] = min(args['batch_size'], 10) # adjust for complex models

support.pretty_print_dict(args)

# Data files
root = args['data_root']
if args['use_refer']:
  # use refer module
  file_name = 'imdb/imdb_%s_refermrf_att.npy' % args['test_split']
else:
  file_name = 'imdb/imdb_%s.npy' % args['test_split']
imdb_path_val = os.path.join(root, file_name)

# assemblers for question and caption programs
question_assembler = Assembler(args['prog_vocab_path'])
caption_assembler = Assembler(args['prog_vocab_path'])
assemblers = {'ques': question_assembler, 'cap': caption_assembler}
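As a follow-up illustration (an assumption, not part of the original example), the assembled `imdb_path_val` would typically be loaded with NumPy before being handed to a data reader:

# Hypothetical usage sketch: load the image database file assembled above.
imdb_data = np.load(imdb_path_val, allow_pickle=True)
print('Loaded %d entries from %s' % (len(imdb_data), imdb_path_val))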
Example #2
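The function below relies on the standard `argparse` and `os` modules plus a project-local `support` module; a minimal sketch of the assumed imports (the `support` path is a guess):

import argparse
import os

# Project-local pretty-printing helper; the exact module path is an assumption.
# from util import support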
def read_command_line():
    title = 'Train explicit coreference resolution visual dialog model'
    parser = argparse.ArgumentParser(description=title)
    #-------------------------------------------------------------------------

    # data input settings
    parser.add_argument('--dataset',
                        default='visdial_v0.9_tiny',
                        help='Visdial dataset type')
    parser.add_argument('--data_root',
                        default='data/',
                        help='Root directory of the data')
    parser.add_argument('--feature_path',
                        default='data/resnet_res5c/',
                        help='Path to the image features')
    parser.add_argument('--text_vocab_path',
                        default='',
                        help='Path to the vocabulary for text')
    parser.add_argument('--prog_vocab_path',
                        default='',
                        help='Path to the vocabulary for programs')
    parser.add_argument('--snapshot_path',
                        default='checkpoints/',
                        help='Path to save checkpoints')
    #--------------------------------------------------------------------------

    # specify encoder/decoder
    parser.add_argument('--model',
                        default='nmn-cap-prog-only',
                        help='Name of the model, will be changed later')
    parser.add_argument('--generator',
                        default='ques',
                        help='Name of the generator to use (ques | memory)')
    parser.add_argument('--decoder',
                        default='gen',
                        help='Name of the decoder to use (gen | disc)')
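    # NOTE: argparse's type=bool converts any non-empty string (including
    # 'False') to True, so the type=bool options below (--preload_features,
    # --enc_dropout, --dec_dropout, --dec_sampling) only behave as expected
    # when left at their defaults.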
    parser.add_argument('--preload_features',
                        default=False,
                        type=bool,
                        help='Preload visual features into RAM')
    #-------------------------------------------------------------------------

    # model hyperparameters
    parser.add_argument('--h_feat',
                        default=14,
                        type=int,
                        help='Height of visual conv feature')
    parser.add_argument('--w_feat',
                        default=14,
                        type=int,
                        help='Width of visual conv feature')
    parser.add_argument('--d_feat',
                        default=2048,
                        type=int,
                        help='Size of visual conv feature')
    parser.add_argument('--text_embed_size',
                        default=300,
                        type=int,
                        help='Size of embedding for text')
    parser.add_argument('--map_size',
                        default=1024,
                        type=int,
                        help='Size of the final mapping')
    parser.add_argument('--prog_embed_size',
                        default=300,
                        type=int,
                        help='Size of embedding for program tokens')
    parser.add_argument('--lstm_size',
                        default=1000,
                        type=int,
                        help='Size of hidden state in LSTM')
    parser.add_argument('--enc_dropout',
                        default=True,
                        type=bool,
                        help='Dropout in encoder')
    parser.add_argument('--dec_dropout',
                        default=True,
                        type=bool,
                        help='Dropout in decoder')
    parser.add_argument('--num_layers',
                        default=2,
                        type=int,
                        help='Number of layers in LSTM')

    parser.add_argument(
        '--max_enc_len',
        default=24,
        type=int,
        help='Maximum encoding length for sentences (ques|cap)')
    parser.add_argument('--max_dec_len',
                        default=14,
                        type=int,
                        help='Maximum decoding length for programs (ques|cap)')
    parser.add_argument('--dec_sampling',
                        default=False,
                        type=bool,
                        help='Sample while decoding programs vs argmax')
    #---------------------------------------------------------------------------

    parser.add_argument('--use_refer',
                        dest='use_refer',
                        action='store_true',
                        help='Flag to use Refer for coreference resolution')
    parser.set_defaults(use_refer=False)
    parser.add_argument('--use_fact',
                        dest='use_fact',
                        action='store_true',
                        help='Flag to use the fact in coreference pool')
    parser.set_defaults(use_fact=False)
    parser.add_argument('--supervise_attention',
                        dest='supervise_attention',
                        action='store_true',
                        help='Flag to supervise attention for the modules')
    parser.set_defaults(supervise_attention=False)
    parser.add_argument('--amalgam_text_feats',
                        dest='amalgam_text_feats',
                        action='store_true',
                        help='Flag to amalgamate text features')
    parser.set_defaults(amalgam_text_feats=False)
    parser.add_argument('--no_cap_alignment',
                        dest='cap_alignment',
                        action='store_false',
                        help='Disable the auxiliary caption alignment loss')
    parser.set_defaults(cap_alignment=True)
    #-------------------------------------------------------------------------

    # optimization params
    parser.add_argument(
        '--batch_size',
        default=20,
        type=int,
        help='Training batch size (adjust based on GPU memory)')
    parser.add_argument('--learning_rate',
                        default=1e-3,
                        type=float,
                        help='Learning rate for training')
    parser.add_argument('--dropout', default=0.5, type=float, help='Dropout')
    parser.add_argument('--num_epochs',
                        default=20,
                        type=int,
                        help='Maximum number of epochs to run training')
    parser.add_argument('--gpu_id',
                        type=int,
                        default=0,
                        help='GPU id to use for training, -1 for CPU')
    #-------------------------------------------------------------------------

    try:
        parsed_args = vars(parser.parse_args())
    except IOError as msg:
        parser.error(str(msg))

    # set the cuda environment variable for the gpu to use
    gpu_id = '' if parsed_args['gpu_id'] < 0 else str(parsed_args['gpu_id'])
    print('Using GPU id: %s' % gpu_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
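    # Note: CUDA_VISIBLE_DEVICES only takes effect if it is set before
    # TensorFlow initializes the GPU (i.e. before any session is created).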

    # pretty print arguments and return
    support.pretty_print_dict(parsed_args)

    return parsed_args
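A minimal usage sketch (an assumption; the original entry point is not shown) of how this parser would typically be invoked from a training script:

if __name__ == '__main__':
    # Parse and pretty-print the command-line options.
    args = read_command_line()
    # Downstream code reads everything from the returned dict, for example:
    print('Dataset: %s' % args['dataset'])
    print('Checkpoints will be written to: %s' % args['snapshot_path'])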