Example #1
0
def main():
    """Training entry point: one worker process per data shard.

    Side effects: sets the multiprocessing start method to 'spawn',
    initializes CUDA in the parent, starts a reader on every DataLoader,
    and spawns ``args.worker`` worker processes which are terminated when
    training finishes.
    """
    args = parse_args()

    mp.set_start_method('spawn')  # Using spawn is decided.
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    # One DataLoader per worker; all loaders share the same shuffled file
    # list but each reads only its own shard (part_no out of part_total).
    loaders = []
    file_list = os.listdir(args.train_file)
    random.shuffle(file_list)
    for i in range(args.worker):
        loader = data_loader.DataLoader(args.train_file,
                                        args.dict_file,
                                        separate_conj_stmt=args.direction,
                                        binary=args.binary,
                                        part_no=i,
                                        part_total=args.worker,
                                        file_list=file_list,
                                        norename=args.norename,
                                        filter_abelian=args.fabelian,
                                        compatible=args.compatible)
        loaders.append(loader)
        loader.start_reader()

    # Touch CUDA once in the parent before spawning workers; only the side
    # effect (context/driver initialization) is needed, so the results are
    # deliberately not bound to names.
    torch.cuda.is_available()
    torch.randn(10).cuda()

    net, mid_net, loss_fn = create_models(args, loaders[0], allow_resume=True)
    # Use fake modules to replace the real ones
    net = FakeModule(net)
    if mid_net is not None:
        mid_net = FakeModule(mid_net)
    loss_fn = [FakeModule(fn) for fn in loss_fn]
    opt = get_opt(net, mid_net, loss_fn, args)

    # Per worker: one one-way pipe for commands in, one for results out.
    inqueues = []
    outqueues = []

    plist = []
    for i in range(args.worker):
        recv_p, send_p = Pipe(False)
        recv_p2, send_p2 = Pipe(False)
        inqueues.append(send_p)
        outqueues.append(recv_p2)
        plist.append(
            Process(target=worker,
                    args=(recv_p, send_p2, loaders[i], args, i)))
        plist[-1].start()

    _logger.warning('Training begins')
    train(inqueues, outqueues, net, mid_net, loss_fn, opt, loaders, args,
          _logger)
    # BUG FIX: the original additionally called ``loader.destruct()`` here
    # on the loop-leftover ``loader`` (the last one built above), which the
    # cleanup loop below then destructed a second time.
    for p in plist:
        p.terminate()
    for loader in loaders:
        loader.destruct()
    _logger.warning('Training ends')
Example #2
0
def main():
    """Test entry point: load a saved model and evaluate it on a test set.

    Parses its own CLI options, loads the checkpoint named by
    --model_path/--model, spawns ``args.worker`` worker processes,
    broadcasts the model state dicts to them over pipes, then sums the
    per-worker correct/total counts into one accuracy number.
    """
    parser = argparse.ArgumentParser(description='The ultimate tester')
    parser.add_argument('--model',
                        type=str,
                        help='The model file name used for testing',
                        required=True)
    parser.add_argument('--model_path',
                        type=str,
                        help='The path to model folder',
                        default='../models')
    parser.add_argument('--log',
                        type=str,
                        help='Path to log file',
                        required=True)
    parser.add_argument('--data',
                        type=str,
                        help='Path to testing set folder',
                        default='../data/hol_data/test')
    parser.add_argument('--worker',
                        type=int,
                        help='Number of workers',
                        default=4)
    parser.add_argument('--max_pair',
                        type=int,
                        help='Change max_pair settings')
    parser.add_argument('--compatible',
                        action='store_true',
                        help='Use compatible mode to run.')
    parser.add_argument('--dict_file', type=str, help='Replace dict')

    settings = parser.parse_args()
    mp.set_start_method('spawn')  # Using spawn is decided.

    _logger = log.get_logger(__name__, settings)
    _logger.info('Test program parameters')
    _logger.info(print_args(settings))
    model_path = os.path.join(settings.model_path, settings.model)
    # ``args`` below are the args stored with the checkpoint; the CLI
    # ``settings`` selectively override them.
    net, mid_net, loss_fn, test_loader, args = load_model(
        model_path, settings.data, settings.compatible)
    args.test_file = settings.data
    if settings.dict_file is not None:
        args.dict_file = settings.dict_file
    if settings.max_pair is not None:
        args.max_pair = settings.max_pair
    args.compatible = settings.compatible
    inqueues = []
    outqueues = []

    # NOTE(review): --worker has default=4, so this condition is always
    # true with the current parser; kept as written.
    if settings.worker is not None:
        args.worker = settings.worker

    # Per worker: a one-way pipe for input (we keep the send end) and a
    # one-way pipe for results (we keep the receive end).
    plist = []
    for i in range(args.worker):
        recv_p, send_p = Pipe(False)
        recv_p2, send_p2 = Pipe(False)
        inqueues.append(send_p)
        outqueues.append(recv_p2)
        plist.append(
            Process(target=worker,
                    args=(recv_p, send_p2, test_loader, args, i)))
        plist[-1].start()

    # Drop the parent's reference to the loader; the spawned workers were
    # handed their own copies (presumably to free memory here).
    test_loader = None

    _logger.info('Model parameters')
    _logger.info(print_args(args))

    valid_start = time.time()
    # Package everything a worker needs to run evaluation: the args plus
    # CPU-serializable state dicts for the net, mid net and loss modules.
    data = {}
    data['args'] = args
    data['net'] = net.state_dict()
    data['fix_net'] = False
    if mid_net is not None:
        data['mid_net'] = mid_net.state_dict()
    data['loss_fn'] = []
    for loss in loss_fn:
        data['loss_fn'].append(loss.state_dict())
    data['test'] = True

    # Broadcast the same payload to every worker.
    for i in range(args.worker):
        inqueues[i].send(data)

    result_correct = 0
    result_total = 0

    _logger.warning('Test start!')
    # Blocking recv from each worker in order; each reports its own
    # correct/total counts which are summed into the final accuracy.
    for i in range(args.worker):
        data = outqueues[i].recv()
        result_correct += data['correct']
        result_total += data['total']
    result_ = result_correct / result_total
    _logger.warning('Validation complete! Time lapse: %.3f, Test acc: %.5f' %
                    (time.time() - valid_start, result_))

    for p in plist:
        p.terminate()
    # Give the terminated worker processes a moment to shut down before
    # the parent exits.
    time.sleep(5)
Example #3
0
def create_models(args, loader, allow_resume=False):
    """Build the graph net, optional mid net and per-step loss modules.

    Args:
        args: parsed command-line namespace (reads nFeats, nSteps, loss,
            loss_step, resume, and the various model flags).
        loader: a DataLoader; only ``loader.dict_size`` is read here.
        allow_resume: when True and ``args.resume`` is set, load weights
            for all modules from the checkpoint at ``args.resume``.

    Returns:
        Tuple ``(net, mid_net, loss_fn)``. ``mid_net`` is None unless
        ``args.loss`` is 'mixmax'/'mixmean'; ``loss_fn`` is a list of
        ``args.loss_step`` loss modules. All modules are moved to CUDA.

    Raises:
        AssertionError: if ``args.loss`` is not a recognized option
        (and ``args.loss_step`` > 0).
    """
    net = model.GraphNet(loader.dict_size,
                         args.nFeats,
                         args.nSteps,
                         args.block,
                         args.module_depth,
                         args.bias,
                         args.short_cut,
                         args.direction,
                         args.loss,
                         args.binary,
                         no_step_supervision=args.no_step_supervision,
                         tied_weight=args.tied_weight,
                         compatible=args.compatible).cuda()

    # The mixing losses need an extra fully-connected projection network.
    mid_net = None
    if args.loss in ('mixmax', 'mixmean'):
        mid_net = model.FullyConnectedNet(args.nFeats,
                                          args.nFeats // 2,
                                          bias=args.bias).cuda()

    loss_fn = []

    # One loss module per supervision step; the module type and its input
    # widths are selected by args.loss.
    for i in range(args.loss_step):
        if args.loss == 'mulloss':
            loss_fn.append(
                loss.MultiplyLoss(args.nFeats,
                                  cond_short_cut=args.cond_short_cut).cuda())
        elif args.loss == 'condloss':
            loss_fn.append(
                loss.CondLoss(args.nFeats * 2,
                              args.nFeats,
                              layer_list=args.loss_layers,
                              dropout=args.dropout,
                              cond_short_cut=args.cond_short_cut).cuda())
        elif args.loss in ('concat', 'concat_em_uc'):
            # Unconditioned / added-conjecture variants classify a single
            # embedding; otherwise the pair of embeddings is concatenated,
            # doubling the input width.
            if args.uncondition or args.add_conj:
                loss_fn.append(
                    loss.ClassifyLoss(args.nFeats,
                                      args.nFeats // 2,
                                      layer_list=args.loss_layers,
                                      dropout=args.dropout,
                                      bias=args.compatible).cuda())
            else:
                loss_fn.append(
                    loss.ClassifyLoss(args.nFeats * 2,
                                      args.nFeats,
                                      layer_list=args.loss_layers,
                                      dropout=args.dropout,
                                      bias=args.compatible).cuda())
        elif args.loss in ('mixmax', 'mixmean'):
            loss_fn.append(
                loss.UCSimLoss(args.nFeats,
                               args.nFeats // 2,
                               layer_list=args.loss_layers,
                               dropout=args.dropout).cuda())
        elif args.loss == 'pair':
            loss_fn.append(
                loss.ClassifyLoss(args.nFeats // 2,
                                  args.nFeats // 4,
                                  layer_list=args.loss_layers,
                                  dropout=args.dropout).cuda())
        elif args.loss == 'em':
            loss_fn.append(
                loss.ClassifyLoss(args.nFeats,
                                  args.nFeats // 2,
                                  layer_list=args.loss_layers,
                                  dropout=args.dropout).cuda())
        else:
            # BUG FIX: was ``assert False, 'Wrong --loss option!'`` which
            # is stripped under ``python -O``, silently accepting a bad
            # --loss.  Raise the same exception type explicitly so the
            # check always fires and existing callers see no difference.
            raise AssertionError('Wrong --loss option!')

    # Optionally restore all module weights from a checkpoint.
    if args.resume is not None and allow_resume:
        data = torch.load(args.resume)
        _logger = log.get_logger('Create Model', args)
        _logger.warning('Load Model!')
        _logger.info('Previous training info:')
        _logger.info('Epoch: %d  Current iter: %d  Total iter: %d',
                     data['aux']['epoch'], data['aux']['cur_iter'],
                     data['aux']['total_iter'])
        _logger.warning('Previous training args:')
        _logger.info(print_args(data['args']))
        net.load_state_dict(data['net']['state_dict'])
        if mid_net is not None:
            mid_net.load_state_dict(data['mid_net']['state_dict'])
        # --resume_only_net restores only the main network, leaving the
        # loss modules freshly initialized.
        if not args.resume_only_net:
            for i in range(len(loss_fn)):
                loss_fn[i].load_state_dict(data['loss_fn'][i]['state_dict'])

    return net, mid_net, loss_fn
Example #4
0
import sys
import argparse
import torch

from model_utils import print_args

# Inspect a training checkpoint: print the stored progress counters and the
# command-line args it was trained with.
# Usage: python <script> CHECKPOINT_FILE
#
# BUG FIX: map_location='cpu' lets the checkpoint be inspected on machines
# without a GPU even when it was saved from CUDA tensors; without it,
# torch.load raises on CPU-only hosts.
data = torch.load(sys.argv[1], map_location='cpu')
print('Previous training info:')
print(
    'Epoch: %d (start from 0)  Current iter: %d  Total iter: %d' %
    (data['aux']['epoch'], data['aux']['cur_iter'], data['aux']['total_iter']))
print(print_args(data['args']))