altcp_tls_mbedtls: fix TX when lower write returns ERR_MEM

This commit is contained in:
goldsimon 2017-03-30 14:55:37 +02:00
parent 51dbd1a7c0
commit 4313bf2a74

View File

@ -81,6 +81,8 @@
#include "mbedtls/memory_buffer_alloc.h" #include "mbedtls/memory_buffer_alloc.h"
#include "mbedtls/ssl_cache.h" #include "mbedtls/ssl_cache.h"
#include "mbedtls/ssl_internal.h" /* to call mbedtls_flush_output after ERR_MEM */
#include <string.h> #include <string.h>
#ifndef ALTCP_MBEDTLS_ENTROPY_PTR #ifndef ALTCP_MBEDTLS_ENTROPY_PTR
@ -283,6 +285,8 @@ altcp_mbedtls_lower_recv_process(struct altcp_pcb *conn, altcp_mbedtls_state_t *
} }
/* If we come here, handshake succeeded. */ /* If we come here, handshake succeeded. */
LWIP_ASSERT("rx pbufs left at end of handshake", state->rx == NULL); LWIP_ASSERT("rx pbufs left at end of handshake", state->rx == NULL);
LWIP_ASSERT("state", state->bio_bytes_read == 0);
LWIP_ASSERT("state", state->bio_bytes_appl == 0);
state->flags |= ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE; state->flags |= ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE;
/* issue "connect" callback" to upper connection (this can only happen for active open) */ /* issue "connect" callback" to upper connection (this can only happen for active open) */
if (conn->connected) { if (conn->connected) {
@ -429,17 +433,20 @@ static int
altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len) altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
{ {
struct altcp_pcb *conn = (struct altcp_pcb *)ctx; struct altcp_pcb *conn = (struct altcp_pcb *)ctx;
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; altcp_mbedtls_state_t *state;
struct pbuf* p; struct pbuf* p;
u16_t ret; u16_t ret;
u16_t copy_len; u16_t copy_len;
err_t err; err_t err;
if (state == NULL) { if ((conn == NULL) || (conn->state == NULL)) {
return 0; return MBEDTLS_ERR_NET_INVALID_CONTEXT;
} }
state = (altcp_mbedtls_state_t *)conn->state;
p = state->rx; p = state->rx;
/* @todo: return MBEDTLS_ERR_NET_CONN_RESET/MBEDTLS_ERR_NET_RECV_FAILED? */
if ((p == NULL) || ((p->len == 0) && (p->next == NULL))) { if ((p == NULL) || ((p->len == 0) && (p->next == NULL))) {
if (p) { if (p) {
pbuf_free(p); pbuf_free(p);
@ -489,6 +496,8 @@ altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len)
/* @todo: do something here? */ /* @todo: do something here? */
return ERR_OK; return ERR_OK;
} }
/* try to send more if we failed before */
mbedtls_ssl_flush_output(&state->ssl_context);
/* call upper sent with len==0 if the application already sent data */ /* call upper sent with len==0 if the application already sent data */
if ((state->flags & ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT) && conn->sent) { if ((state->flags & ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT) && conn->sent) {
return conn->sent(conn->arg, conn, 0); return conn->sent(conn->arg, conn, 0);
@ -509,8 +518,10 @@ altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn)
LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
/* check if there's unreceived rx data */ /* check if there's unreceived rx data */
if (conn->state) { if (conn->state) {
if (altcp_mbedtls_handle_rx_appldata(conn, (altcp_mbedtls_state_t *)conn->state) == altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
ERR_ABRT) { /* try to send more if we failed before */
mbedtls_ssl_flush_output(&state->ssl_context);
if (altcp_mbedtls_handle_rx_appldata(conn, state) == ERR_ABRT) {
return ERR_ABRT; return ERR_ABRT;
} }
} }
@ -855,19 +866,34 @@ altcp_mbedtls_write(struct altcp_pcb *conn, const void *dataptr, u16_t len, u8_t
return ERR_VAL; return ERR_VAL;
} }
/* HACK: if thre is something left to send, try to flush it and only
allow sending more if this succeeded (this is a hack because neither
returning 0 nor MBEDTLS_ERR_SSL_WANT_WRITE worked for me) */
if (state->ssl_context.out_left) {
mbedtls_ssl_flush_output(&state->ssl_context);
if (state->ssl_context.out_left) {
return ERR_MEM;
}
}
ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len); ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len);
/* try to send data... */ /* try to send data... */
altcp_output(conn->inner_conn); altcp_output(conn->inner_conn);
if(ret == len) { if (ret >= 0) {
state->flags |= ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT; if(ret == len) {
return ERR_OK; state->flags |= ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT;
} else if (ret <= 0) { return ERR_OK;
/* @todo: convert error to err_t */ } else {
return ERR_MEM; /* @todo/@fixme: assumption: either everything sent or error */
LWIP_ASSERT("ret <= 0", 0);
return ERR_MEM;
}
} else { } else {
/* assumption: either everything sent or error */ if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
LWIP_ASSERT("ret <= 0", 0); /* @todo: convert error to err_t */
return ERR_MEM; return ERR_MEM;
}
LWIP_ASSERT("unhandled error", 0);
return ERR_VAL;
} }
} }
@ -884,6 +910,9 @@ altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size)
u8_t apiflags = TCP_WRITE_FLAG_COPY; u8_t apiflags = TCP_WRITE_FLAG_COPY;
LWIP_ASSERT("conn != NULL", conn != NULL); LWIP_ASSERT("conn != NULL", conn != NULL);
if ((conn == NULL) || (conn->inner_conn == NULL)) {
return MBEDTLS_ERR_NET_INVALID_CONTEXT;
}
while (size_left) { while (size_left) {
u16_t write_len = (u16_t)LWIP_MIN(size_left, 0xFFFF); u16_t write_len = (u16_t)LWIP_MIN(size_left, 0xFFFF);
@ -891,9 +920,15 @@ altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size)
if (err == ERR_OK) { if (err == ERR_OK) {
written += write_len; written += write_len;
size_left -= write_len; size_left -= write_len;
} else if (err == ERR_MEM) {
if (written) {
return written;
}
return 0;//MBEDTLS_ERR_SSL_WANT_WRITE;
} else { } else {
LWIP_ASSERT("tls_write, tcp_write: ERR MEM", err == ERR_MEM ); LWIP_ASSERT("tls_write, tcp_write: err != ERR MEM", 0);
break; /* @todo: return MBEDTLS_ERR_NET_CONN_RESET or MBEDTLS_ERR_NET_SEND_FAILED */
return MBEDTLS_ERR_NET_SEND_FAILED;
} }
} }
return written; return written;