From 81191041c48c56c27ebcdae8c6b61aa48cd3419a Mon Sep 17 00:00:00 2001 From: Astro Date: Fri, 7 Sep 2018 00:12:00 +0200 Subject: [PATCH] improve style: flatten future --- src/client/mod.rs | 8 +- src/component/mod.rs | 4 +- src/happy_eyeballs.rs | 245 ++++++++++++++++++++++++------------------ 3 files changed, 144 insertions(+), 113 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 6c3226c2..cf4553a0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -61,7 +61,7 @@ impl Client { done(idna::domain_to_ascii(&jid.domain)) .map_err(|_| Error::Idna) .and_then(|domain| - done(Connecter::from_lookup(&domain, "_xmpp-client._tcp", 5222)) + done(Connecter::from_lookup(&domain, Some("_xmpp-client._tcp"), 5222)) .map_err(Error::Connection) ) .and_then(|connecter| @@ -75,10 +75,8 @@ impl Client { } else { Err(Error::Protocol(ProtocolError::NoTls)) } - }).and_then(|starttls| - // TODO: flatten? - starttls - ).and_then(|tls_stream| + }).flatten() + .and_then(|tls_stream| XMPPStream::start(tls_stream, jid2, NS_JABBER_CLIENT.to_owned()) ).and_then(move |xmpp_stream| done(Self::auth(xmpp_stream, username, password)) diff --git a/src/component/mod.rs b/src/component/mod.rs index 088d20d5..e34ecd4d 100644 --- a/src/component/mod.rs +++ b/src/component/mod.rs @@ -53,8 +53,8 @@ impl Component { fn make_connect(jid: Jid, password: String, server: &str, port: u16) -> impl Future { let jid1 = jid.clone(); let password = password; - done(Connecter::from_lookup(server, "_xmpp-component._tcp", port)) - .and_then(|connecter| connecter) + done(Connecter::from_lookup(server, None, port)) + .flatten() .map_err(Error::Connection) .and_then(move |tcp_stream| { xmpp_stream::XMPPStream::start(tcp_stream, jid1, NS_JABBER_COMPONENT_ACCEPT.to_owned()) diff --git a/src/happy_eyeballs.rs b/src/happy_eyeballs.rs index e3bebe99..6b22cca0 100644 --- a/src/happy_eyeballs.rs +++ b/src/happy_eyeballs.rs @@ -1,44 +1,61 @@ use std::mem; -use std::net::{SocketAddr, IpAddr}; -use std::collections::{BTreeMap, btree_map}; +use std::net::SocketAddr; +use std::collections::BTreeMap; use std::collections::VecDeque; +use std::cell::RefCell; use futures::{Future, Poll, Async}; use tokio::net::{ConnectFuture, TcpStream}; use trust_dns_resolver::{IntoName, Name, ResolverFuture, error::ResolveError}; use trust_dns_resolver::lookup::SrvLookupFuture; use trust_dns_resolver::lookup_ip::LookupIpFuture; -use trust_dns_proto::rr::rdata::srv::SRV; use ConnecterError; +enum State { + AwaitResolver(Box + Send>), + ResolveSrv(ResolverFuture, SrvLookupFuture), + ResolveTarget(ResolverFuture, LookupIpFuture, u16), + Connecting(Option, Vec>), + Invalid, +} + pub struct Connecter { fallback_port: u16, - name: Name, + srv_domain: Option, domain: Name, - resolver_future: Box + Send>, - resolver_opt: Option, - srv_lookup_opt: Option, - srvs_opt: Option>, - ip_lookup_opt: Option<(u16, LookupIpFuture)>, - ips_opt: Option<(u16, VecDeque)>, - connect_opt: Option, + state: State, + targets: VecDeque<(Name, u16)>, } impl Connecter { - pub fn from_lookup(domain: &str, srv: &str, fallback_port: u16) -> Result { + pub fn from_lookup(domain: &str, srv: Option<&str>, fallback_port: u16) -> Result { + if let Ok(ip) = domain.parse() { + // use specified IP address, not domain name, skip the whole dns part + let connect = + RefCell::new(TcpStream::connect(&SocketAddr::new(ip, fallback_port))); + return Ok(Connecter { + fallback_port, + srv_domain: None, + domain: "nohost".into_name()?, + state: State::Connecting(None, vec![connect]), + targets: VecDeque::new(), + }); + } + let resolver_future = ResolverFuture::from_system_conf()?; - let name = format!("{}.{}.", srv, domain).into_name()?; + let state = State::AwaitResolver(resolver_future); + let srv_domain = match srv { + Some(srv) => + Some(format!("{}.{}.", srv, domain).into_name()?), + None => + None, + }; Ok(Connecter { fallback_port, - name, + srv_domain, domain: domain.into_name()?, - resolver_future, - resolver_opt: None, - srv_lookup_opt: None, - srvs_opt: None, - ip_lookup_opt: None, - ips_opt: None, - connect_opt: None, + state, + targets: VecDeque::new(), }) } } @@ -48,102 +65,118 @@ impl Future for Connecter { type Error = ConnecterError; fn poll(&mut self) -> Poll { - if self.resolver_opt.is_none() { - //println!("Poll resolver future"); - match self.resolver_future.poll()? { - Async::Ready(resolver) => - self.resolver_opt = Some(resolver), - Async::NotReady => - return Ok(Async::NotReady), - } - } - - if let Some(ref resolver) = self.resolver_opt { - if self.srvs_opt.is_none() { - if self.srv_lookup_opt.is_none() { - //println!("Lookup srv: {:?}", self.name); - self.srv_lookup_opt = Some(resolver.lookup_srv(&self.name)); - } - - if let Some(ref mut srv_lookup) = self.srv_lookup_opt { - match srv_lookup.poll() { - Ok(Async::Ready(t)) => { - let mut srvs = BTreeMap::new(); - for srv in t.iter() { - srvs.insert(srv.priority(), srv.clone()); + let state = mem::replace(&mut self.state, State::Invalid); + match state { + State::AwaitResolver(mut resolver_future) => { + match resolver_future.poll()? { + Async::NotReady => { + self.state = State::AwaitResolver(resolver_future); + Ok(Async::NotReady) + } + Async::Ready(resolver) => { + match &self.srv_domain { + &Some(ref srv_domain) => { + let srv_lookup = resolver.lookup_srv(srv_domain); + self.state = State::ResolveSrv(resolver, srv_lookup); + } + None => { + self.targets = + [(self.domain.clone(), self.fallback_port)].into_iter() + .cloned() + .collect(); + self.state = State::Connecting(Some(resolver), vec![]); } - srvs.insert(65535, SRV::new(65535, 0, self.fallback_port, self.domain.clone())); - self.srvs_opt = Some(srvs.into_iter()); } - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(_) => { - //println!("Ignore SVR error: {:?}", e); - let mut srvs = BTreeMap::new(); - srvs.insert(65535, SRV::new(65535, 0, self.fallback_port, self.domain.clone())); - self.srvs_opt = Some(srvs.into_iter()); - }, + self.poll() } } } - - if self.connect_opt.is_none() { - if self.ips_opt.is_none() { - if self.ip_lookup_opt.is_none() { - if let Some(ref mut srvs) = self.srvs_opt { - if let Some((_, srv)) = srvs.next() { - //println!("Lookup ip: {:?}", srv); - self.ip_lookup_opt = Some((srv.port(), resolver.lookup_ip(srv.target()))); - } else { - return Err(ConnecterError::NoSrv); - } - } - } - - if let Some((port, mut ip_lookup)) = mem::replace(&mut self.ip_lookup_opt, None) { - match ip_lookup.poll() { - Ok(Async::Ready(t)) => { - let mut ip_deque = VecDeque::new(); - ip_deque.extend(t.iter()); - //println!("IPs: {:?}", ip_deque); - self.ips_opt = Some((port, ip_deque)); - self.ip_lookup_opt = None; - }, - Ok(Async::NotReady) => { - self.ip_lookup_opt = Some((port, ip_lookup)); - return Ok(Async::NotReady) - }, - Err(_) => { - //println!("Ignore lookup error: {:?}", e); - self.ip_lookup_opt = None; - } - } - } - } - - if let Some((port, mut ip_deque)) = mem::replace(&mut self.ips_opt, None) { - if let Some(ip) = ip_deque.pop_front() { - //println!("Connect to {:?}:{}", ip, port); - self.connect_opt = Some(TcpStream::connect(&SocketAddr::new(ip, port))); - self.ips_opt = Some((port, ip_deque)); - } - } - } - - if let Some(mut connect_future) = mem::replace(&mut self.connect_opt, None) { - match connect_future.poll() { - Ok(Async::Ready(t)) => return Ok(Async::Ready(t)), + State::ResolveSrv(resolver, mut srv_lookup) => { + match srv_lookup.poll() { Ok(Async::NotReady) => { - self.connect_opt = Some(connect_future); - return Ok(Async::NotReady) + self.state = State::ResolveSrv(resolver, srv_lookup); + Ok(Async::NotReady) + } + Ok(Async::Ready(srv_result)) => { + let mut srv_map: BTreeMap<_, _> = + srv_result.iter() + .map(|srv| (srv.priority(), (srv.target().clone(), srv.port()))) + .collect(); + let targets = + srv_map.into_iter() + .map(|(_, tp)| tp) + .collect(); + self.targets = targets; + self.state = State::Connecting(Some(resolver), vec![]); + self.poll() } Err(_) => { - //println!("Ignore connect error: {:?}", e); - }, + // ignore, fallback + self.targets = + [(self.domain.clone(), self.fallback_port)].into_iter() + .cloned() + .collect(); + self.state = State::Connecting(Some(resolver), vec![]); + self.poll() + } } } + State::Connecting(resolver, mut connects) => { + if resolver.is_some() && + connects.len() == 0 && + self.targets.len() > 0 { + let resolver = resolver.unwrap(); + let (host, port) = self.targets.pop_front().unwrap(); + let ip_lookup = resolver.lookup_ip(host); + self.state = State::ResolveTarget(resolver, ip_lookup, port); + self.poll() + } else if connects.len() > 0 { + let mut success = None; + connects.retain(|connect| { + match connect.borrow_mut().poll() { + Ok(Async::NotReady) => true, + Ok(Async::Ready(connection)) => { + success = Some(connection); + false + } + Err(_) => false, + } + }); + match success { + Some(connection) => + Ok(Async::Ready(connection)), + None => { + self.state = State::Connecting(resolver, connects); + Ok(Async::NotReady) + }, + } + } else { + Err(ConnecterError::AllFailed) + } + } + State::ResolveTarget(resolver, mut ip_lookup, port) => { + match ip_lookup.poll() { + Ok(Async::NotReady) => { + self.state = State::ResolveTarget(resolver, ip_lookup, port); + Ok(Async::NotReady) + } + Ok(Async::Ready(ip_result)) => { + let connects = + ip_result.iter() + .map(|ip| RefCell::new(TcpStream::connect(&SocketAddr::new(ip, port)))) + .collect(); + self.state = State::Connecting(Some(resolver), connects); + self.poll() + } + Err(_) => { + // ignore, next… + self.state = State::Connecting(Some(resolver), vec![]); + self.poll() + } + } + } + _ => panic!("") } - - Ok(Async::NotReady) } }