Ejemplo n.º 1
0
                    mode='gem',
                    sync_every=args.sync_every,
                    worker_optimizer=args.worker_optimizer,
                    worker_optimizer_params=args.worker_optimizer_params,
                    learning_rate=args.gem_lr,
                    momentum=args.gem_momentum,
                    kappa=args.gem_kappa)
    else:
        algo = Algo(args.optimizer,
                    loss=args.loss,
                    validate_every=validate_every,
                    sync_every=args.sync_every,
                    worker_optimizer=args.worker_optimizer,
                    worker_optimizer_params=args.worker_optimizer_params)
    if args.restore:
        algo.load(args.restore)

    # Creating the MPIManager object causes all needed worker and master nodes to be created
    manager = MPIManager(comm=comm,
                         data=data,
                         algo=algo,
                         model_builder=model_builder,
                         num_epochs=args.epochs,
                         train_list=train_list,
                         val_list=val_list,
                         num_masters=args.masters,
                         num_processes=args.processes,
                         synchronous=args.synchronous,
                         verbose=args.verbose,
                         monitor=args.monitor,
                         early_stopping=args.early_stopping,
Ejemplo n.º 2
0
                    loss=args.loss,
                    validate_every=validate_every,
                    mode='easgd',
                    sync_every=args.sync_every,
                    worker_optimizer=args.worker_optimizer,
                    elastic_force=args.elastic_force / (comm.Get_size() - 1),
                    elastic_lr=args.elastic_lr,
                    elastic_momentum=args.elastic_momentum)
    else:
        algo = Algo(args.optimizer,
                    loss=args.loss,
                    validate_every=validate_every,
                    sync_every=args.sync_every,
                    worker_optimizer=args.worker_optimizer)
    if args.load_algo:
        algo.load(args.load_algo)

    # Creating the MPIManager object causes all needed worker and master nodes to be created
    manager = MPIManager(comm=comm,
                         data=data,
                         algo=algo,
                         model_builder=model_builder,
                         num_epochs=args.epochs,
                         train_list=train_list,
                         val_list=val_list,
                         num_masters=args.masters,
                         synchronous=args.synchronous,
                         verbose=args.verbose,
                         monitor=args.monitor,
                         early_stopping=args.early_stopping,
                         target_metric=args.target_metric)